diff --git a/crates/nu-cli/src/commands/alias.rs b/crates/nu-cli/src/commands/alias.rs index 501844f386..a12b2c98d9 100644 --- a/crates/nu-cli/src/commands/alias.rs +++ b/crates/nu-cli/src/commands/alias.rs @@ -3,10 +3,14 @@ use crate::context::CommandRegistry; use crate::prelude::*; use nu_data::config; use nu_errors::ShellError; +use nu_parser::SignatureRegistry; +use nu_protocol::hir::{ClassifiedCommand, Expression, NamedValue, SpannedExpression, Variable}; use nu_protocol::{ - hir::Block, CommandAction, ReturnSuccess, Signature, SyntaxShape, UntaggedValue, Value, + hir::Block, CommandAction, NamedType, PositionalType, ReturnSuccess, Signature, SyntaxShape, + UntaggedValue, Value, }; use nu_source::Tagged; +use std::collections::HashMap; pub struct Alias; @@ -134,10 +138,183 @@ pub async fn alias( } Ok(OutputStream::one(ReturnSuccess::action( - CommandAction::AddAlias(name.to_string(), processed_args, block), + CommandAction::AddAlias( + name.to_string(), + to_arg_shapes(processed_args, &block, ®istry)?, + block, + ), ))) } +fn to_arg_shapes( + args: Vec, + block: &Block, + registry: &CommandRegistry, +) -> Result, ShellError> { + match find_block_shapes(block, registry) { + Ok(found) => Ok(args + .iter() + .map(|arg| { + ( + arg.clone(), + match found.get(arg) { + None | Some((_, None)) => SyntaxShape::Any, + Some((_, Some(shape))) => *shape, + }, + ) + }) + .collect()), + Err(err) => Err(err), + } +} + +type ShapeMap = HashMap)>; + +fn check_insert( + existing: &mut ShapeMap, + to_add: (String, (Span, Option)), +) -> Result<(), ShellError> { + match (to_add.1).1 { + None => match existing.get(&to_add.0) { + None => { + existing.insert(to_add.0, to_add.1); + Ok(()) + } + Some(_) => Ok(()), + }, + Some(new) => match existing.insert(to_add.0.clone(), ((to_add.1).0, Some(new))) { + None => Ok(()), + Some(exist) => match exist.1 { + None => Ok(()), + Some(shape) => match shape { + SyntaxShape::Any => Ok(()), + shape if shape == new => Ok(()), + _ => Err(ShellError::labeled_error( + "Type conflict in alias variable use", + "creates type conflict", + (to_add.1).0, + )), + }, + }, + }, + } +} + +fn check_merge(existing: &mut ShapeMap, new: &ShapeMap) -> Result<(), ShellError> { + for (k, v) in new.iter() { + check_insert(existing, (k.clone(), *v))?; + } + + Ok(()) +} + +fn find_expr_shapes( + spanned_expr: &SpannedExpression, + registry: &CommandRegistry, +) -> Result { + match &spanned_expr.expr { + // TODO range will need similar if/when invocations can be parsed within range expression + Expression::Binary(bin) => find_expr_shapes(&bin.left, registry).and_then(|mut left| { + find_expr_shapes(&bin.right, registry) + .and_then(|right| check_merge(&mut left, &right).map(|()| left)) + }), + Expression::Block(b) => find_block_shapes(&b, registry), + Expression::Path(path) => match &path.head.expr { + Expression::Invocation(b) => find_block_shapes(&b, registry), + Expression::Variable(Variable::Other(var, _)) => { + let mut result = HashMap::new(); + result.insert(var.to_string(), (spanned_expr.span, None)); + Ok(result) + } + _ => Ok(HashMap::new()), + }, + _ => Ok(HashMap::new()), + } +} + +fn find_block_shapes(block: &Block, registry: &CommandRegistry) -> Result { + let apply_shape = |found: ShapeMap, sig_shape: SyntaxShape| -> ShapeMap { + found + .iter() + .map(|(v, sh)| match sh.1 { + None => (v.clone(), (sh.0, Some(sig_shape))), + Some(shape) => (v.clone(), (sh.0, Some(shape))), + }) + .collect() + }; + + let mut arg_shapes = HashMap::new(); + for pipeline in &block.block { + for classified in &pipeline.list { + match classified { + ClassifiedCommand::Expr(spanned_expr) => { + let found = find_expr_shapes(&spanned_expr, registry)?; + check_merge(&mut arg_shapes, &found)? + } + ClassifiedCommand::Internal(internal) => { + if let Some(signature) = registry.get(&internal.name) { + if let Some(positional) = &internal.args.positional { + for (i, spanned_expr) in positional.iter().enumerate() { + let found = find_expr_shapes(&spanned_expr, registry)?; + if i >= signature.positional.len() { + if let Some((sig_shape, _)) = &signature.rest_positional { + check_merge( + &mut arg_shapes, + &apply_shape(found, *sig_shape), + )?; + } else { + unreachable!("should have error'd in parsing"); + } + } else { + let (pos_type, _) = &signature.positional[i]; + match pos_type { + // TODO pass on mandatory/optional? + PositionalType::Mandatory(_, sig_shape) + | PositionalType::Optional(_, sig_shape) => { + check_merge( + &mut arg_shapes, + &apply_shape(found, *sig_shape), + )?; + } + } + } + } + } + + if let Some(named) = &internal.args.named { + for (name, val) in named.iter() { + if let NamedValue::Value(_, spanned_expr) = val { + let found = find_expr_shapes(&spanned_expr, registry)?; + match signature.named.get(name) { + None => { + unreachable!("should have error'd in parsing"); + } + Some((named_type, _)) => { + if let NamedType::Mandatory(_, sig_shape) + | NamedType::Optional(_, sig_shape) = named_type + { + check_merge( + &mut arg_shapes, + &apply_shape(found, *sig_shape), + )?; + } + } + } + } + } + } + } else { + unreachable!("registry has lost name it provided"); + } + } + ClassifiedCommand::Dynamic(_) | ClassifiedCommand::Error(_) => (), + } + } + } + + Ok(arg_shapes) +} + #[cfg(test)] mod tests { use super::Alias; diff --git a/crates/nu-cli/src/commands/run_alias.rs b/crates/nu-cli/src/commands/run_alias.rs index 74915e0e37..4d161c8c0b 100644 --- a/crates/nu-cli/src/commands/run_alias.rs +++ b/crates/nu-cli/src/commands/run_alias.rs @@ -9,7 +9,7 @@ use nu_protocol::{hir::Block, Signature, SyntaxShape}; #[derive(new, Clone)] pub struct AliasCommand { name: String, - args: Vec, + args: Vec<(String, SyntaxShape)>, block: Block, } @@ -22,8 +22,8 @@ impl WholeStreamCommand for AliasCommand { fn signature(&self) -> Signature { let mut alias = Signature::build(&self.name); - for arg in &self.args { - alias = alias.optional(arg, SyntaxShape::Any, ""); + for (arg, shape) in &self.args { + alias = alias.optional(arg, *shape, ""); } alias @@ -53,7 +53,7 @@ impl WholeStreamCommand for AliasCommand { for (pos, arg) in positional.iter().enumerate() { scope .vars - .insert(alias_command.args[pos].to_string(), arg.clone()); + .insert(alias_command.args[pos].0.to_string(), arg.clone()); } } diff --git a/crates/nu-cli/tests/commands/alias.rs b/crates/nu-cli/tests/commands/alias.rs index acda5b8f2c..27b685983f 100644 --- a/crates/nu-cli/tests/commands/alias.rs +++ b/crates/nu-cli/tests/commands/alias.rs @@ -15,3 +15,85 @@ fn alias_args_work() { assert_eq!(actual.out, "[1,2]"); }) } + +#[test] +#[cfg(not(windows))] +fn alias_parses_path_tilde() { + let actual = nu!( + cwd: "tests/fixtures/formats", + r#" + alias new-cd [dir] { cd $dir } + new-cd ~ + pwd + "# + ); + + #[cfg(target_os = "linux")] + assert!(actual.out.contains("home")); + #[cfg(target_os = "macos")] + assert!(actual.out.contains("Users")); +} + +#[test] +fn error_alias_wrong_shape_shallow() { + let actual = nu!( + cwd: ".", + r#" + alias round-to [num digits] { echo $num | str from -d $digits } + round-to 3.45 a + "# + ); + + assert!(actual.err.contains("Type")); +} + +#[test] +fn error_alias_wrong_shape_deep_invocation() { + let actual = nu!( + cwd: ".", + r#" + alias round-to [nums digits] { echo $nums | each {= $(str from -d $digits)}} + round-to 3.45 a + "# + ); + + assert!(actual.err.contains("Type")); +} + +#[test] +fn error_alias_wrong_shape_deep_binary() { + let actual = nu!( + cwd: ".", + r#" + alias round-plus-one [nums digits] { echo $nums | each {= $(str from -d $digits | str to-decimal) + 1}} + round-plus-one 3.45 a + "# + ); + + assert!(actual.err.contains("Type")); +} + +#[test] +fn error_alias_wrong_shape_deeper_binary() { + let actual = nu!( + cwd: ".", + r#" + alias round-one-more [num digits] { echo $num | str from -d $(= $digits + 1) } + round-one-more 3.45 a + "# + ); + + assert!(actual.err.contains("Type")); +} + +#[test] +fn error_alias_syntax_shape_clash() { + let actual = nu!( + cwd: ".", + r#" + alias clash [a] { echo 1.1 2 3 | each { str from -d $a } | range $a } } + "# + ); + + assert!(actual.err.contains("alias")); +} diff --git a/crates/nu-protocol/src/return_value.rs b/crates/nu-protocol/src/return_value.rs index dc17f6979e..37cf0f8b62 100644 --- a/crates/nu-protocol/src/return_value.rs +++ b/crates/nu-protocol/src/return_value.rs @@ -1,5 +1,6 @@ use crate::hir::Block; use crate::value::Value; +use crate::SyntaxShape; use nu_errors::ShellError; use nu_source::{b, DebugDocBuilder, PrettyDebug}; use serde::{Deserialize, Serialize}; @@ -22,7 +23,7 @@ pub enum CommandAction { /// Enter the help shell, which allows exploring the help system EnterHelpShell(Value), /// Enter the help shell, which allows exploring the help system - AddAlias(String, Vec, Block), + AddAlias(String, Vec<(String, SyntaxShape)>, Block), /// Go to the previous shell in the shell ring buffer PreviousShell, /// Go to the next shell in the shell ring buffer diff --git a/crates/nu-protocol/src/syntax_shape.rs b/crates/nu-protocol/src/syntax_shape.rs index 4b788421bf..dc4cb73770 100644 --- a/crates/nu-protocol/src/syntax_shape.rs +++ b/crates/nu-protocol/src/syntax_shape.rs @@ -2,7 +2,7 @@ use nu_source::{b, DebugDocBuilder, PrettyDebug}; use serde::{Deserialize, Serialize}; /// The syntactic shapes that values must match to be passed into a command. You can think of this as the type-checking that occurs when you call a function. -#[derive(Debug, Copy, Clone, Serialize, Deserialize)] +#[derive(Debug, Copy, Clone, Serialize, Deserialize, PartialEq)] pub enum SyntaxShape { /// Any syntactic form is allowed Any,