diff --git a/crates/nu-parser/src/flag.rs b/crates/nu-parser/src/flag.rs new file mode 100644 index 0000000000..bd300e009a --- /dev/null +++ b/crates/nu-parser/src/flag.rs @@ -0,0 +1,138 @@ +use nu_errors::{ArgumentError, ParseError}; +use nu_protocol::hir::InternalCommand; +use nu_protocol::NamedType; +use nu_source::{Span, Spanned, SpannedItem}; + +/// Match the available flags in a signature with what the user provided. This will check both long-form flags (--long) and shorthand flags (-l) +/// This also allows users to provide a group of shorthand flags (-la) that correspond to multiple shorthand flags at once. +pub fn get_flag_signature_spec( + signature: &nu_protocol::Signature, + cmd: &InternalCommand, + arg: &Spanned, +) -> (Vec<(String, NamedType)>, Option) { + if arg.item.starts_with('-') { + // It's a flag (or set of flags) + let mut output = vec![]; + let mut error = None; + + let remainder: String = arg.item.chars().skip(1).collect(); + + if remainder.starts_with('-') { + // Long flag expected + let mut remainder: String = remainder.chars().skip(1).collect(); + + if remainder.contains('=') { + let assignment: Vec<_> = remainder.split('=').collect(); + + if assignment.len() != 2 { + error = Some(ParseError::argument_error( + cmd.name.to_string().spanned(cmd.name_span), + ArgumentError::InvalidExternalWord, + )); + } else { + remainder = assignment[0].to_string(); + } + } + + if let Some((named_type, _)) = signature.named.get(&remainder) { + output.push((remainder.clone(), named_type.clone())); + } else { + error = Some(ParseError::argument_error( + cmd.name.to_string().spanned(cmd.name_span), + ArgumentError::UnexpectedFlag(arg.clone()), + )); + } + } else { + // Short flag(s) expected + let mut starting_pos = arg.span.start() + 1; + for c in remainder.chars() { + let mut found = false; + for (full_name, named_arg) in signature.named.iter() { + if Some(c) == named_arg.0.get_short() { + found = true; + output.push((full_name.clone(), named_arg.0.clone())); + break; + } + } + + if !found { + error = Some(ParseError::argument_error( + cmd.name.to_string().spanned(cmd.name_span), + ArgumentError::UnexpectedFlag( + arg.item + .clone() + .spanned(Span::new(starting_pos, starting_pos + c.len_utf8())), + ), + )); + } + + starting_pos += c.len_utf8(); + } + } + + (output, error) + } else { + // It's not a flag, so don't bother with it + (vec![], None) + } +} + +#[cfg(test)] +mod tests { + use super::get_flag_signature_spec; + use crate::{lex, parse_block}; + use nu_protocol::{hir::InternalCommand, NamedType, Signature, SyntaxShape}; + use nu_source::{HasSpan, Span}; + + fn bundle() -> Signature { + Signature::build("bundle add") + .switch("skip-install", "Adds the gem to the Gemfile but does not install it.", None) + .named("group", SyntaxShape::String, "Specify the group(s) for the added gem. Multiple groups should be separated by commas.", Some('g')) + .rest(SyntaxShape::Any, "options") + } + + #[test] + fn parses_longform_flag_containing_equal_sign() { + let input = "bundle add rails --group=development"; + let (tokens, _) = lex(&input, 0); + let (root_node, _) = parse_block(tokens); + + assert_eq!(root_node.block.len(), 1); + assert_eq!(root_node.block[0].pipelines.len(), 1); + assert_eq!(root_node.block[0].pipelines[0].commands.len(), 1); + assert_eq!(root_node.block[0].pipelines[0].commands[0].parts.len(), 4); + + let command_node = root_node.block[0].pipelines[0].commands[0].clone(); + let idx = 1; + + let (name, name_span) = ( + command_node.parts[0..(idx + 1)] + .iter() + .map(|x| x.item.clone()) + .collect::>() + .join(" "), + Span::new( + command_node.parts[0].span.start(), + command_node.parts[idx].span.end(), + ), + ); + + let mut internal = InternalCommand::new(name, name_span, command_node.span()); + + let signature = bundle(); + + internal.args.set_initial_flags(&signature); + + let (flags, err) = get_flag_signature_spec(&signature, &internal, &command_node.parts[3]); + let (long_name, spec) = flags[0].clone(); + + assert!(err.is_none()); + assert_eq!(long_name, "group".to_string()); + assert_eq!(spec.get_short(), Some('g')); + + match spec { + NamedType::Optional(_, _) => {} + _ => panic!("optional flag didn't parse succesfully"), + } + } +} diff --git a/crates/nu-parser/src/lib.rs b/crates/nu-parser/src/lib.rs index 92f728090b..9ad9768598 100644 --- a/crates/nu-parser/src/lib.rs +++ b/crates/nu-parser/src/lib.rs @@ -4,6 +4,7 @@ extern crate derive_is_enum_variant; extern crate derive_new; mod errors; +mod flag; mod lex; mod parse; mod path; diff --git a/crates/nu-parser/src/parse.rs b/crates/nu-parser/src/parse.rs index bcf88b4999..d867b1eb63 100644 --- a/crates/nu-parser/src/parse.rs +++ b/crates/nu-parser/src/parse.rs @@ -1110,66 +1110,6 @@ fn parse_arg( } } -/// Match the available flags in a signature with what the user provided. This will check both long-form flags (--long) and shorthand flags (-l) -/// This also allows users to provide a group of shorthand flags (-la) that correspond to multiple shorthand flags at once. -fn get_flags_from_flag( - signature: &nu_protocol::Signature, - cmd: &InternalCommand, - arg: &Spanned, -) -> (Vec<(String, NamedType)>, Option) { - if arg.item.starts_with('-') { - // It's a flag (or set of flags) - let mut output = vec![]; - let mut error = None; - - let remainder: String = arg.item.chars().skip(1).collect(); - - if remainder.starts_with('-') { - // Long flag expected - let remainder: String = remainder.chars().skip(1).collect(); - if let Some((named_type, _)) = signature.named.get(&remainder) { - output.push((remainder.clone(), named_type.clone())); - } else { - error = Some(ParseError::argument_error( - cmd.name.to_string().spanned(cmd.name_span), - ArgumentError::UnexpectedFlag(arg.clone()), - )); - } - } else { - // Short flag(s) expected - let mut starting_pos = arg.span.start() + 1; - for c in remainder.chars() { - let mut found = false; - for (full_name, named_arg) in signature.named.iter() { - if Some(c) == named_arg.0.get_short() { - found = true; - output.push((full_name.clone(), named_arg.0.clone())); - break; - } - } - - if !found { - error = Some(ParseError::argument_error( - cmd.name.to_string().spanned(cmd.name_span), - ArgumentError::UnexpectedFlag( - arg.item - .clone() - .spanned(Span::new(starting_pos, starting_pos + c.len_utf8())), - ), - )); - } - - starting_pos += c.len_utf8(); - } - } - - (output, error) - } else { - // It's not a flag, so don't bother with it - (vec![], None) - } -} - /// This is a bit of a "fix-up" of previously parsed areas. In cases where we're in shorthand mode (eg in the `where` command), we need /// to use the original source to parse a column path. Without it, we'll lose a little too much information to parse it correctly. As we'll /// only know we were on the left-hand side of an expression after we do the full math parse, we need to do this step after rather than during @@ -1486,14 +1426,53 @@ fn parse_internal_command( while idx < lite_cmd.parts.len() { if lite_cmd.parts[idx].item.starts_with('-') && lite_cmd.parts[idx].item.len() > 1 { - let (named_types, err) = - get_flags_from_flag(&signature, &internal_command, &lite_cmd.parts[idx]); + let (named_types, err) = super::flag::get_flag_signature_spec( + &signature, + &internal_command, + &lite_cmd.parts[idx], + ); if err.is_none() { for (full_name, named_type) in &named_types { match named_type { NamedType::Mandatory(_, shape) | NamedType::Optional(_, shape) => { - if idx == lite_cmd.parts.len() { + if lite_cmd.parts[idx].item.contains('=') { + let mut offset = 0; + + lite_cmd.parts[idx] + .item + .chars() + .skip_while(|prop| { + offset += 1; + *prop != '=' + }) + .skip(1) + .for_each(drop); + + let flag_value = Span::new_option( + lite_cmd.parts[idx].span.start() + + (lite_cmd.parts[idx].span.start() - offset), + lite_cmd.parts[idx].span.end(), + ); + + if let Some(value_span) = flag_value { + let value = lite_cmd.parts[idx].item[offset..] + .to_string() + .spanned(value_span); + + let (arg, err) = parse_arg(*shape, scope, &value); + + named.insert_mandatory( + full_name.clone(), + lite_cmd.parts[idx].span, + arg, + ); + + if error.is_none() { + error = err; + } + } + } else if idx == lite_cmd.parts.len() { // Oops, we're missing the argument to our named argument if error.is_none() { error = Some(ParseError::argument_error( diff --git a/crates/nu-source/src/meta.rs b/crates/nu-source/src/meta.rs index 370fff1ef8..6fb0fa10c7 100644 --- a/crates/nu-source/src/meta.rs +++ b/crates/nu-source/src/meta.rs @@ -525,6 +525,14 @@ impl Span { Span { start, end } } + pub fn new_option(start: usize, end: usize) -> Option { + if end >= start { + None + } else { + Some(Span { start, end }) + } + } + /// Creates a `Span` with a length of 1 from the given position. /// /// # Example