diff --git a/Cargo.lock b/Cargo.lock index 0fe327602a..77671ba8e2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -608,6 +608,7 @@ dependencies = [ "nu-protocol", "nu-table", "nu-term-grid", + "rayon", "sysinfo", "terminal_size", "thiserror", diff --git a/crates/nu-command/Cargo.toml b/crates/nu-command/Cargo.toml index 352d7daf45..1a805e4c6a 100644 --- a/crates/nu-command/Cargo.toml +++ b/crates/nu-command/Cargo.toml @@ -26,6 +26,7 @@ terminal_size = "0.1.17" lscolors = { version = "0.8.0", features = ["crossterm"] } bytesize = "1.1.0" dialoguer = "0.9.0" +rayon = "1.5.1" [features] trash-support = ["trash"] diff --git a/crates/nu-command/src/default_context.rs b/crates/nu-command/src/default_context.rs index b1ba99470b..6fe0b3b13f 100644 --- a/crates/nu-command/src/default_context.rs +++ b/crates/nu-command/src/default_context.rs @@ -51,6 +51,7 @@ pub fn create_default_context() -> EngineState { Mkdir, Module, Mv, + ParEach, Ps, Rm, Select, diff --git a/crates/nu-command/src/filters/mod.rs b/crates/nu-command/src/filters/mod.rs index e4e68902da..bb87c7944b 100644 --- a/crates/nu-command/src/filters/mod.rs +++ b/crates/nu-command/src/filters/mod.rs @@ -2,6 +2,7 @@ mod each; mod get; mod length; mod lines; +mod par_each; mod select; mod where_; mod wrap; @@ -10,6 +11,7 @@ pub use each::Each; pub use get::Get; pub use length::Length; pub use lines::Lines; +pub use par_each::ParEach; pub use select::Select; pub use where_::Where; pub use wrap::Wrap; diff --git a/crates/nu-command/src/filters/par_each.rs b/crates/nu-command/src/filters/par_each.rs new file mode 100644 index 0000000000..79ecab43ec --- /dev/null +++ b/crates/nu-command/src/filters/par_each.rs @@ -0,0 +1,252 @@ +use nu_engine::eval_block; +use nu_protocol::ast::Call; +use nu_protocol::engine::{Command, EngineState, Stack}; +use nu_protocol::{Example, IntoPipelineData, PipelineData, Signature, SyntaxShape, Value}; +use rayon::prelude::*; + +#[derive(Clone)] +pub struct ParEach; + +impl Command for ParEach { + fn name(&self) -> &str { + "par-each" + } + + fn usage(&self) -> &str { + "Run a block on each element of input in parallel" + } + + fn signature(&self) -> nu_protocol::Signature { + Signature::build("each") + .required( + "block", + SyntaxShape::Block(Some(vec![SyntaxShape::Any])), + "the block to run", + ) + .switch("numbered", "iterate with an index", Some('n')) + } + + fn examples(&self) -> Vec { + vec![Example { + example: "[1 2 3] | each { 2 * $it }", + description: "Multiplies elements in list", + result: None, + }] + } + + fn run( + &self, + engine_state: &EngineState, + stack: &mut Stack, + call: &Call, + input: PipelineData, + ) -> Result { + let block_id = call.positional[0] + .as_block() + .expect("internal error: expected block"); + + let numbered = call.has_flag("numbered"); + let engine_state = engine_state.clone(); + let block = engine_state.get_block(block_id); + let mut stack = stack.collect_captures(&block.captures); + let span = call.head; + + match input { + PipelineData::Value(Value::Range { val, .. }) => Ok(val + .into_range_iter()? + .enumerate() + .par_bridge() + .map(move |(idx, x)| { + let block = engine_state.get_block(block_id); + + let mut stack = stack.clone(); + + if let Some(var) = block.signature.get_positional(0) { + if let Some(var_id) = &var.var_id { + if numbered { + stack.add_var( + *var_id, + Value::Record { + cols: vec!["index".into(), "item".into()], + vals: vec![ + Value::Int { + val: idx as i64, + span, + }, + x, + ], + span, + }, + ); + } else { + stack.add_var(*var_id, x); + } + } + } + + match eval_block(&engine_state, &mut stack, block, PipelineData::new()) { + Ok(v) => v, + Err(error) => Value::Error { error }.into_pipeline_data(), + } + }) + .collect::>() + .into_iter() + .flatten() + .into_pipeline_data()), + PipelineData::Value(Value::List { vals: val, .. }) => Ok(val + .into_iter() + .enumerate() + .par_bridge() + .map(move |(idx, x)| { + let block = engine_state.get_block(block_id); + + let mut stack = stack.clone(); + + if let Some(var) = block.signature.get_positional(0) { + if let Some(var_id) = &var.var_id { + if numbered { + stack.add_var( + *var_id, + Value::Record { + cols: vec!["index".into(), "item".into()], + vals: vec![ + Value::Int { + val: idx as i64, + span, + }, + x, + ], + span, + }, + ); + } else { + stack.add_var(*var_id, x); + } + } + } + + match eval_block(&engine_state, &mut stack, block, PipelineData::new()) { + Ok(v) => v, + Err(error) => Value::Error { error }.into_pipeline_data(), + } + }) + .collect::>() + .into_iter() + .flatten() + .into_pipeline_data()), + PipelineData::Stream(stream) => Ok(stream + .enumerate() + .par_bridge() + .map(move |(idx, x)| { + let block = engine_state.get_block(block_id); + + let mut stack = stack.clone(); + + if let Some(var) = block.signature.get_positional(0) { + if let Some(var_id) = &var.var_id { + if numbered { + stack.add_var( + *var_id, + Value::Record { + cols: vec!["index".into(), "item".into()], + vals: vec![ + Value::Int { + val: idx as i64, + span, + }, + x, + ], + span, + }, + ); + } else { + stack.add_var(*var_id, x); + } + } + } + + match eval_block(&engine_state, &mut stack, block, PipelineData::new()) { + Ok(v) => v, + Err(error) => Value::Error { error }.into_pipeline_data(), + } + }) + .collect::>() + .into_iter() + .flatten() + .into_pipeline_data()), + PipelineData::Value(Value::Record { cols, vals, .. }) => { + let mut output_cols = vec![]; + let mut output_vals = vec![]; + + for (col, val) in cols.into_iter().zip(vals.into_iter()) { + let block = engine_state.get_block(block_id); + + let mut stack = stack.clone(); + + if let Some(var) = block.signature.get_positional(0) { + if let Some(var_id) = &var.var_id { + stack.add_var( + *var_id, + Value::Record { + cols: vec!["column".into(), "value".into()], + vals: vec![ + Value::String { + val: col.clone(), + span: call.head, + }, + val, + ], + span: call.head, + }, + ); + } + } + + match eval_block(&engine_state, &mut stack, block, PipelineData::new())? { + PipelineData::Value(Value::Record { + mut cols, mut vals, .. + }) => { + // TODO check that the lengths match when traversing record + output_cols.append(&mut cols); + output_vals.append(&mut vals); + } + x => { + output_cols.push(col); + output_vals.push(x.into_value()); + } + } + } + + Ok(Value::Record { + cols: output_cols, + vals: output_vals, + span: call.head, + } + .into_pipeline_data()) + } + PipelineData::Value(x) => { + let block = engine_state.get_block(block_id); + + if let Some(var) = block.signature.get_positional(0) { + if let Some(var_id) = &var.var_id { + stack.add_var(*var_id, x); + } + } + + eval_block(&engine_state, &mut stack, block, PipelineData::new()) + } + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_examples() { + use crate::test_examples; + + test_examples(ParEach {}) + } +} diff --git a/src/tests.rs b/src/tests.rs index f990e6cc87..fa2d323065 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -621,6 +621,14 @@ fn for_loops() -> TestResult { run_test(r#"(for x in [1, 2, 3] { $x + 10 }).1"#, "12") } +#[test] +fn par_each() -> TestResult { + run_test( + r#"1..10 | par-each --numbered { ([[index, item]; [$it.index, ($it.item > 5)]]).0 } | where index == 4 | get item.0"#, + "false", + ) +} + #[test] fn type_in_list_of_this_type() -> TestResult { run_test(r#"42 in [41 42 43]"#, "true")