Merge pull request #19 from svanveen/fix/parse-args

Simplify structure of parse_args_internal
This commit is contained in:
Pranav Srinivas Kumar 2019-05-24 15:01:33 -04:00 committed by GitHub
commit 7c5ee10205
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 80 additions and 185 deletions

View File

@ -73,12 +73,6 @@ using enable_if_container = std::enable_if_t<is_container_v<T>, T>;
template <typename T>
using enable_if_not_container = std::enable_if_t<!is_container_v<T>, T>;
// Check if string (haystack) starts with a substring (needle)
bool starts_with(const std::string& haystack, const std::string& needle) {
return needle.length() <= haystack.length()
&& std::equal(needle.begin(), needle.end(), haystack.begin());
}
}
class Argument {
@ -118,6 +112,33 @@ public:
return *this;
}
template <typename Iterator>
Iterator consume(Iterator start, Iterator end, std::string usedName = {}) {
if (mIsUsed) {
throw std::runtime_error("Duplicate argument");
}
mIsUsed = true;
mUsedName = std::move(usedName);
if (mNumArgs == 0) {
mValues.emplace_back(mImplicitValue);
return start;
}
else if (mNumArgs <= std::distance(start, end)) {
end = std::next(start, mNumArgs);
if (std::any_of(start, end, Argument::is_optional)) {
throw std::runtime_error("optional argument in parameter sequence");
}
std::transform(start, end, std::back_inserter(mValues), mAction);
return end;
}
else if (mDefaultValue.has_value()) {
return start;
}
else {
throw std::runtime_error("Too few arguments");
}
}
/*
* @throws std::runtime_error if argument values are not valid
*/
@ -134,7 +155,7 @@ public:
}
}
else {
if (mValues.size() != mNumArgs) {
if (mValues.size() != mNumArgs && !mDefaultValue.has_value()) {
std::stringstream stream;
stream << "error: " << mUsedName << ": expected " << mNumArgs << " argument(s). "
<< mValues.size() << " provided.\n" << std::endl;
@ -180,7 +201,11 @@ public:
private:
// If an argument starts with "-" or "--", then it's optional
static bool is_optional(const std::string& aName) {
return (starts_with(aName, "--") || starts_with(aName, "-"));
return (!aName.empty() && aName[0] == '-');
}
static bool is_positional(const std::string& aName) {
return !is_optional(aName);
}
/*
@ -302,8 +327,9 @@ class ArgumentParser {
* @throws std::runtime_error in case of any invalid argument
*/
void parse_args(int argc, const char * const argv[]) {
parse_args_internal(argc, argv);
parse_args_validate();
std::vector<std::string> arguments;
std::copy(argv, argv + argc, std::back_inserter(arguments));
parse_args(arguments);
}
/* Getter enabled for all template types other than std::vector and std::list
@ -406,132 +432,45 @@ class ArgumentParser {
* @throws std::runtime_error in case of any invalid argument
*/
void parse_args_internal(const std::vector<std::string>& aArguments) {
std::vector<char*> argv;
for (const auto& arg : aArguments)
argv.emplace_back(const_cast<char*>(arg.data()));
argv.emplace_back(nullptr);
return parse_args_internal(int(argv.size()) - 1, argv.data());
if (mProgramName.empty() && !aArguments.empty()) {
mProgramName = aArguments.front();
}
/*
* @throws std::runtime_error in case of any invalid argument
*/
void parse_args_internal(int argc, const char * const argv[]) {
if (mProgramName.empty() && argc > 0)
mProgramName = argv[0];
for (int i = 1; i < argc; i++) {
auto tCurrentArgument = std::string(argv[i]);
auto end = std::end(aArguments);
auto positionalArgumentIt = std::begin(mPositionalArguments);
for (auto it = std::next(std::begin(aArguments)); it != end;) {
const auto& tCurrentArgument = *it;
if (tCurrentArgument == Argument::mHelpOption || tCurrentArgument == Argument::mHelpOptionLong) {
throw std::runtime_error("help called");
}
auto tIterator = mArgumentMap.find(argv[i]);
if (tIterator != mArgumentMap.end()) {
// Start parsing optional argument
if (Argument::is_positional(tCurrentArgument)) {
if (positionalArgumentIt == std::end(mPositionalArguments)) {
throw std::runtime_error("Maximum number of positional arguments exceeded");
}
auto tArgument = *(positionalArgumentIt++);
it = tArgument->consume(it, end);
}
else if (auto tIterator = mArgumentMap.find(tCurrentArgument); tIterator != mArgumentMap.end()) {
auto tArgument = tIterator->second;
tArgument->mUsedName = tCurrentArgument;
tArgument->mIsUsed = true;
auto tCount = tArgument->mNumArgs;
// Check to see if implicit value should be used
// Two cases to handle here:
// (1) User has explicitly programmed nargs to be 0
// (2) User has provided an implicit value, which also sets nargs to 0
if (tCount == 0) {
// Use implicit value for this optional argument
tArgument->mValues.emplace_back(tArgument->mImplicitValue);
tArgument->mRawValues.emplace_back();
tCount = 0;
it = tArgument->consume(std::next(it), end, tCurrentArgument);
}
while (tCount > 0) {
i = i + 1;
if (i < argc) {
tArgument->mUsedName = tCurrentArgument;
tArgument->mRawValues.emplace_back(argv[i]);
if (tArgument->mAction != nullptr)
tArgument->mValues.emplace_back(tArgument->mAction(argv[i]));
else {
if (tArgument->mDefaultValue.has_value())
tArgument->mValues.emplace_back(tArgument->mDefaultValue);
else
tArgument->mValues.emplace_back(std::string(argv[i]));
}
}
tCount -= 1;
}
}
else {
if (Argument::is_optional(argv[i])) {
// This is possibly a compound optional argument
// Example: We have three optional arguments -a, -u and -x
// The user provides ./main -aux ...
// Here -aux is a compound optional argument
std::string tCompoundArgument = std::string(argv[i]);
if (tCompoundArgument.size() > 1 && tCompoundArgument[0] == '-' && tCompoundArgument[1] != '-') {
else if (const auto& tCompoundArgument = tCurrentArgument;
tCompoundArgument.size() > 1 &&
tCompoundArgument[0] == '-' &&
tCompoundArgument[1] != '-') {
++it;
for (size_t j = 1; j < tCompoundArgument.size(); j++) {
std::string tArgument(1, tCompoundArgument[j]);
size_t tNumArgs = 0;
tIterator = mArgumentMap.find("-" + tArgument);
if (tIterator != mArgumentMap.end()) {
auto tArgumentObject = tIterator->second;
tNumArgs = tArgumentObject->mNumArgs;
std::vector<std::string> tArgumentsForRecursiveParsing = {"", "-" + tArgument};
while (tNumArgs > 0 && i < argc) {
i += 1;
if (i < argc) {
tArgumentsForRecursiveParsing.emplace_back(argv[i]);
tNumArgs -= 1;
}
}
parse_args_internal(tArgumentsForRecursiveParsing);
auto tCurrentArgument = std::string{'-', tCompoundArgument[j]};
if (auto tIterator = mArgumentMap.find(tCurrentArgument); tIterator != mArgumentMap.end()) {
auto tArgument = tIterator->second;
it = tArgument->consume(it, end, tCurrentArgument);
}
else {
if (!tArgument.empty() && tArgument[0] == '-')
std::cout << "warning: unrecognized optional argument " << tArgument
<< std::endl;
else
std::cout << "warning: unrecognized optional argument -" << tArgument
<< std::endl;
throw std::runtime_error("Unknown argument");
}
}
}
else {
std::cout << "warning: unrecognized optional argument " << tCompoundArgument << std::endl;
}
}
else {
// This is a positional argument.
// Parse and save into mPositionalArguments vector
if (mNextPositionalArgument >= mPositionalArguments.size()) {
std::stringstream stream;
stream << "error: unexpected positional argument " << argv[i] << std::endl;
throw std::runtime_error(stream.str());
}
auto tArgument = mPositionalArguments[mNextPositionalArgument];
auto tCount = tArgument->mNumArgs - tArgument->mRawValues.size();
while (tCount > 0) {
tIterator = mArgumentMap.find(argv[i]);
if (tIterator != mArgumentMap.end() || Argument::is_optional(argv[i])) {
i = i - 1;
break;
}
if (i < argc) {
tArgument->mUsedName = tCurrentArgument;
tArgument->mRawValues.emplace_back(argv[i]);
if (tArgument->mAction != nullptr)
tArgument->mValues.emplace_back(tArgument->mAction(argv[i]));
else {
if (tArgument->mDefaultValue.has_value())
tArgument->mValues.emplace_back(tArgument->mDefaultValue);
else
tArgument->mValues.emplace_back(std::string(argv[i]));
}
}
tCount -= 1;
if (tCount > 0) i += 1;
}
if (tCount == 0)
mNextPositionalArgument += 1;
}
throw std::runtime_error("Unknown argument");
}
}
}
@ -540,26 +479,20 @@ class ArgumentParser {
* @throws std::runtime_error in case of any invalid argument
*/
void parse_args_validate() {
try {
// Check if all positional arguments are parsed
std::for_each(std::begin(mPositionalArguments),
std::end(mPositionalArguments),
std::mem_fn(&Argument::validate));
// Check if all user-provided optional argument values are parsed correctly
std::for_each(std::begin(mOptionalArguments),
std::end(mOptionalArguments),
std::mem_fn(&Argument::validate));
} catch (const std::runtime_error& err) {
throw err;
}
// Check if all arguments are parsed
std::for_each(std::begin(mArgumentMap), std::end(mArgumentMap), [](const auto& argPair) {
const auto& [key, arg] = argPair;
arg->validate();
});
}
// Used by print_help.
size_t get_length_of_longest_argument(const std::vector<std::shared_ptr<Argument>>& aArguments) {
if (aArguments.empty())
size_t get_length_of_longest_argument() {
if (mArgumentMap.empty())
return 0;
std::vector<size_t> argumentLengths(aArguments.size());
std::transform(std::begin(aArguments), std::end(aArguments), std::begin(argumentLengths), [](const auto& arg) {
std::vector<size_t> argumentLengths(mArgumentMap.size());
std::transform(std::begin(mArgumentMap), std::end(mArgumentMap), std::begin(argumentLengths), [](const auto& argPair) {
const auto& [key, arg] = argPair;
const auto& names = arg->mNames;
auto maxLength = std::accumulate(std::begin(names), std::end(names), std::string::size_type{0}, [](const auto& sum, const auto& s) {
return sum + s.size() + 2; // +2 for ", "
@ -569,19 +502,10 @@ class ArgumentParser {
return *std::max_element(std::begin(argumentLengths), std::end(argumentLengths));
}
// Used by print_help.
size_t get_length_of_longest_argument() {
const auto positionalArgMaxSize = get_length_of_longest_argument(mPositionalArguments);
const auto optionalArgMaxSize = get_length_of_longest_argument(mOptionalArguments);
return std::max(positionalArgMaxSize, optionalArgMaxSize);
}
std::string mProgramName;
std::vector<ArgumentParser> mParentParsers;
std::vector<std::shared_ptr<Argument>> mPositionalArguments;
std::vector<std::shared_ptr<Argument>> mOptionalArguments;
size_t mNextPositionalArgument = 0;
std::map<std::string, std::shared_ptr<Argument>> mArgumentMap;
};

View File

@ -77,30 +77,7 @@ TEST_CASE("Parse compound toggle arguments with implicit values and nargs and ot
program.add_argument("--input_files")
.nargs(3);
program.parse_args({ "./test.exe", "1", "-abc", "3.14", "2.718", "2", "--input_files",
"a.txt", "b.txt", "c.txt", "3" });
REQUIRE(program.get<bool>("-a") == true);
REQUIRE(program.get<bool>("-b") == true);
auto c = program.get<std::vector<float>>("-c");
REQUIRE(c.size() == 2);
REQUIRE(c[0] == 3.14f);
REQUIRE(c[1] == 2.718f);
auto input_files = program.get<std::vector<std::string>>("--input_files");
REQUIRE(input_files.size() == 3);
REQUIRE(input_files[0] == "a.txt");
REQUIRE(input_files[1] == "b.txt");
REQUIRE(input_files[2] == "c.txt");
auto numbers = program.get<std::vector<int>>("numbers");
REQUIRE(numbers.size() == 3);
REQUIRE(numbers[0] == 1);
REQUIRE(numbers[1] == 2);
REQUIRE(numbers[2] == 3);
auto numbers_list = program.get<std::list<int>>("numbers");
REQUIRE(numbers.size() == 3);
REQUIRE(testutility::get_from_list(numbers_list, 0) == 1);
REQUIRE(testutility::get_from_list(numbers_list, 1) == 2);
REQUIRE(testutility::get_from_list(numbers_list, 2) == 3);
REQUIRE_THROWS(program.parse_args({ "./test.exe", "1", "-abc", "3.14", "2.718", "2", "--input_files", "a.txt", "b.txt", "c.txt", "3" }));
}
TEST_CASE("Parse out-of-order compound arguments", "[compound_arguments]") {

View File

@ -34,5 +34,5 @@ TEST_CASE("Parse unknown optional argument", "[compound_arguments]") {
.action([](const std::string& val) { return std::stoull(val); })
.help("memory in MB to give the VMM when loading");
bfm.parse_args({ "./test.exe", "-om" });
REQUIRE_THROWS(bfm.parse_args({ "./test.exe", "-om" }));
}

View File

@ -44,13 +44,7 @@ TEST_CASE("Parse positional arguments with optional arguments in the middle", "[
program.add_argument("output").nargs(2);
program.add_argument("--num_iterations")
.action([](const std::string& value) { return std::stoi(value); });
program.parse_args({ "test", "rocket.mesh", "thrust_profile.csv", "--num_iterations", "15", "output.mesh" });
REQUIRE(program.get<int>("--num_iterations") == 15);
REQUIRE(program.get("input") == "rocket.mesh");
auto outputs = program.get<std::vector<std::string>>("output");
REQUIRE(outputs.size() == 2);
REQUIRE(outputs[0] == "thrust_profile.csv");
REQUIRE(outputs[1] == "output.mesh");
REQUIRE_THROWS(program.parse_args({ "test", "rocket.mesh", "thrust_profile.csv", "--num_iterations", "15", "output.mesh" }));
}
TEST_CASE("Square a number", "[positional_arguments]") {