diff --git a/src/argparse.lua b/src/argparse.lua index f45e68f..0906d5a 100644 --- a/src/argparse.lua +++ b/src/argparse.lua @@ -208,11 +208,36 @@ function Parser:make_command_names() end end +function Parser:make_types() + for _, elements in ipairs{self.arguments, self.options} do + for _, element in ipairs(elements) do + if element.maxcount == 1 then + if element.maxargs == 0 then + element.type = "flag" + elseif element.maxargs == 1 and element.minargs == 1 then + element.type = "arg" + else + element.type = "multi-arg" + end + else + if element.maxargs == 0 then + element.type = "counter" + elseif element.maxargs == 1 and element.minargs == 1 then + element.type = "multi-count" + else + element.type = "multi-count multi-arg" + end + end + end + end +end + function Parser:prepare() self:make_charset() self:make_targets() self:make_boundaries() self:make_command_names() + self:make_types() return self end @@ -228,18 +253,86 @@ function Parser:parse(args) local opt_context = {} local com_context local result = {} + local invocations = {} + local passed = {} local cur_option local cur_arg_i = 1 local cur_arg - local function close(element) - local invocations = result[element.target] - local passed = invocations[#invocations] + local function convert(element, data) + if element.convert then + local ok, err = element.convert(data) - if #passed < element.minargs then + return parser:assert(ok, "%s", err or "malformed argument " .. data) + else + return data + end + end + + local invoke, pass, close + + function invoke(element) + local overwrite = false + + if invocations[element] == element.maxcount then + if element.overwrite then + overwrite = true + else + parser:error("option %s must be used at most %d times", element.name, element.maxcount) + end + else + invocations[element] = invocations[element]+1 + end + + passed[element] = 0 + + if element.type == "flag" then + result[element.target] = true + elseif element.type == "multi-arg" then + result[element.target] = {} + elseif element.type == "counter" then + if not overwrite then + result[element.target] = result[element.target]+1 + end + elseif element.type == "multi-count" then + if overwrite then + table.remove(result[element.target], 1) + end + elseif element.type == "multi-count multi-arg" then + table.insert(result[element.target], {}) + + if overwrite then + table.remove(result[element.target], 1) + end + end + + if element.maxargs == 0 then + close(element) + end + end + + function pass(element, data) + passed[element] = passed[element]+1 + data = convert(element, data) + + if element.type == "arg" then + result[element.target] = data + elseif element.type == "multi-arg" or element.type == "multi-count" then + table.insert(result[element.target], data) + elseif element.type == "multi-count multi-arg" then + table.insert(result[element.target][#result[element.target]], data) + end + + if passed[element] == element.maxargs then + close(element) + end + end + + function close(element) + if passed[element] < element.minargs then if element.default then - while #passed < element.minargs do - table.insert(passed, element.default) + while passed[element] < element.minargs do + pass(element, element.default) end else parser:error("too few arguments") @@ -254,34 +347,6 @@ function Parser:parse(args) end end - local function invoke(element) - local invocations = result[element.target] - - if #invocations == element.maxcount then - if element.overwrite then - table.remove(invocations, 1) - else - parser:error("option %s must be used at most %d times", element.name, element.maxcount) - end - end - - table.insert(result[element.target], {}) - - if element.maxargs == 0 then - close(element) - end - end - - local function pass(element, data) - local invocations = result[element.target] - local passed = invocations[#invocations] - table.insert(passed, data) - - if #passed == element.maxargs then - close(element) - end - end - local function switch(p) parser = p:prepare() charset = p.charset @@ -293,12 +358,18 @@ function Parser:parse(args) opt_context[alias] = option end - result[option.target] = {} + if option.type == "counter" then + result[option.target] = 0 + elseif option.type == "multi-count" or option.type == "multi-count multi-arg" then + result[option.target] = {} + end + + invocations[option] = 0 end for _, argument in ipairs(p.arguments) do table.insert(arguments, argument) - result[argument.target] = {} + invocations[argument] = 0 invoke(argument) end @@ -398,67 +469,6 @@ function Parser:parse(args) end end - local function convert(element, data) - if element.convert then - local ok, err = element.convert(data) - - return parser:assert(ok, "%s", err or "malformed argument " .. data) - else - return data - end - end - - local function format() - local invocations - - for _, elements in ipairs{options, arguments} do - for _, element in ipairs(elements) do - invocations = result[element.target] - - parser:assert(#invocations >= element.mincount, - "option %s must be used at least %d times", element.name, element.mincount) - - if element.maxcount == 1 then - if element.maxargs == 0 then - if #invocations > 0 then - result[element.target] = true - else - result[element.target] = nil - end - elseif element.maxargs == 1 and element.minargs == 1 then - if #invocations > 0 then - result[element.target] = convert(element, invocations[1][1]) - else - result[element.target] = nil - end - else - result[element.target] = invocations[1] - - if #invocations > 0 and element.convert then - for i=1, #result[element.target] do - result[element.target][i] = convert(element, result[element.target][i]) - end - end - end - else - if element.maxargs == 0 then - result[element.target] = #invocations - elseif element.maxargs == 1 and element.minargs == 1 then - for i=1, #invocations do - invocations[i] = convert(element, invocations[i][1]) - end - elseif element.convert then - for _, invocation in ipairs(invocations) do - for i=1, #invocation do - invocation[i] = convert(element, invocation[i]) - end - end - end - end - end - end - end - switch(self) mainloop() @@ -474,7 +484,11 @@ function Parser:parse(args) parser:error("command is required") end - format() + for _, option in ipairs(options) do + parser:assert(invocations[option] >= option.mincount, + "option %s must be used at least %d times", option.name, option.mincount + ) + end return result end