Using enable_if and template specializations to deal with parser.get<vector<T>>(..)

This commit is contained in:
Pranav Srinivas Kumar 2019-03-30 13:47:59 -04:00
parent 16e3f58ce9
commit 717e6aafdb
2 changed files with 22 additions and 13 deletions

View File

@ -4,9 +4,16 @@
#include <functional> #include <functional>
#include <any> #include <any>
#include <memory> #include <memory>
#include <type_traits> // C++0x
namespace argparse { namespace argparse {
template<typename Test, template<typename...> class Ref>
struct is_specialization : std::false_type {};
template<template<typename...> class Ref, typename... Args>
struct is_specialization<Ref<Args...>, Ref> : std::true_type {};
template <class KeyType, class ElementType> template <class KeyType, class ElementType>
bool upsert(std::map<KeyType, ElementType>& aMap, KeyType const& aKey, ElementType const& aNewValue) { bool upsert(std::map<KeyType, ElementType>& aMap, KeyType const& aKey, ElementType const& aNewValue) {
typedef typename std::map<KeyType, ElementType>::iterator Iterator; typedef typename std::map<KeyType, ElementType>::iterator Iterator;
@ -84,24 +91,24 @@ struct Argument {
} }
template <typename T> template <typename T>
std::vector<T> get_list() { T get_vector() {
std::vector<T> tResult; T tResult;
if (mValues.size() == 0) { if (mValues.size() == 0) {
if (mDefaultValue != nullptr) { if (mDefaultValue != nullptr) {
std::any tDefaultValueLambdaResult = mDefaultValue(); std::any tDefaultValueLambdaResult = mDefaultValue();
std::vector<T> tDefaultValues = std::any_cast<std::vector<T>>(tDefaultValueLambdaResult); T tDefaultValues = std::any_cast<T>(tDefaultValueLambdaResult);
for (size_t i = 0; i < tDefaultValues.size(); i++) { for (size_t i = 0; i < tDefaultValues.size(); i++) {
tResult.push_back(std::any_cast<T>(tDefaultValues[i])); tResult.push_back(std::any_cast<typename T::value_type>(tDefaultValues[i]));
} }
return tResult; return tResult;
} }
else else
return std::vector<T>(); return T();
} }
else { else {
if (mRawValues.size() > 0) { if (mRawValues.size() > 0) {
for (size_t i = 0; i < mValues.size(); i++) { for (size_t i = 0; i < mValues.size(); i++) {
tResult.push_back(std::any_cast<T>(mValues[i])); tResult.push_back(std::any_cast<typename T::value_type>(mValues[i]));
} }
return tResult; return tResult;
} }
@ -110,12 +117,12 @@ struct Argument {
std::any tDefaultValueLambdaResult = mDefaultValue(); std::any tDefaultValueLambdaResult = mDefaultValue();
std::vector<T> tDefaultValues = std::any_cast<std::vector<T>>(tDefaultValueLambdaResult); std::vector<T> tDefaultValues = std::any_cast<std::vector<T>>(tDefaultValueLambdaResult);
for (size_t i = 0; i < tDefaultValues.size(); i++) { for (size_t i = 0; i < tDefaultValues.size(); i++) {
tResult.push_back(std::any_cast<T>(tDefaultValues[i])); tResult.push_back(std::any_cast<typename T::value_type>(tDefaultValues[i]));
} }
return tResult; return tResult;
} }
else else
return std::vector<T>(); return T();
} }
} }
} }
@ -175,7 +182,8 @@ class ArgumentParser {
} }
template <typename T = std::string> template <typename T = std::string>
T get(const char * aArgumentName) { typename std::enable_if<is_specialization<T, std::vector>::value == false, T>::type
get(const char * aArgumentName) {
std::map<std::string, std::shared_ptr<Argument>>::iterator tIterator = mArgumentMap.find(aArgumentName); std::map<std::string, std::shared_ptr<Argument>>::iterator tIterator = mArgumentMap.find(aArgumentName);
if (tIterator != mArgumentMap.end()) { if (tIterator != mArgumentMap.end()) {
return tIterator->second->get<T>(); return tIterator->second->get<T>();
@ -184,12 +192,13 @@ class ArgumentParser {
} }
template <typename T> template <typename T>
std::vector<T> get_list(const char * aArgumentName) { typename std::enable_if<is_specialization<T, std::vector>::value, T>::type
get(const char * aArgumentName) {
std::map<std::string, std::shared_ptr<Argument>>::iterator tIterator = mArgumentMap.find(aArgumentName); std::map<std::string, std::shared_ptr<Argument>>::iterator tIterator = mArgumentMap.find(aArgumentName);
if (tIterator != mArgumentMap.end()) { if (tIterator != mArgumentMap.end()) {
return tIterator->second->get_list<T>(); return tIterator->second->get_vector<T>();
} }
return std::vector<T>(); return T();
} }
private: private:

View File

@ -27,7 +27,7 @@ int main(int argc, char * argv[]) {
auto config_file = program.get<std::string>("--config"); auto config_file = program.get<std::string>("--config");
auto num_iters = program.get<int>("-n"); auto num_iters = program.get<int>("-n");
auto verbose = program.get<bool>("-v"); auto verbose = program.get<bool>("-v");
auto test_inputs = program.get_list<int>("--test_inputs"); auto test_inputs = program.get<std::vector<int>>("--test_inputs");
std::cout << config_file << std::endl; std::cout << config_file << std::endl;
std::cout << num_iters << std::endl; std::cout << num_iters << std::endl;