From 724b177c97016a3da1b9da3aa7840ee267d3722f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Andr=C3=A9s=20N=2E=20Robalino?=
Date: Thu, 6 Aug 2020 23:56:19 -0500
Subject: [PATCH] Sample variance and Sample standard deviation. (#2310)

---
 crates/nu-cli/src/commands/math/stddev.rs   | 101 ++++++++++++++++----
 crates/nu-cli/src/commands/math/variance.rs |  97 +++++++++++++++----
 2 files changed, 161 insertions(+), 37 deletions(-)

diff --git a/crates/nu-cli/src/commands/math/stddev.rs b/crates/nu-cli/src/commands/math/stddev.rs
index a350db7d0f..2e087d0f11 100644
--- a/crates/nu-cli/src/commands/math/stddev.rs
+++ b/crates/nu-cli/src/commands/math/stddev.rs
@@ -1,13 +1,18 @@
-use super::variance::variance;
-use crate::commands::math::utils::run_with_function;
+use super::variance::compute_variance as variance;
 use crate::commands::WholeStreamCommand;
 use crate::prelude::*;
 use nu_errors::ShellError;
-use nu_protocol::{Primitive, Signature, UntaggedValue, Value};
+use nu_protocol::{Dictionary, Primitive, ReturnSuccess, Signature, UntaggedValue, Value};
+use nu_source::Tagged;
 use std::str::FromStr;
 
 pub struct SubCommand;
 
+#[derive(Deserialize)]
+struct Arguments {
+    sample: Tagged<bool>,
+}
+
 #[async_trait]
 impl WholeStreamCommand for SubCommand {
     fn name(&self) -> &str {
@@ -15,7 +20,11 @@ impl WholeStreamCommand for SubCommand {
     }
 
     fn signature(&self) -> Signature {
-        Signature::build("math stddev")
+        Signature::build("math stddev").switch(
+            "sample",
+            "calculate sample standard deviation",
+            Some('s'),
+        )
     }
 
     fn usage(&self) -> &str {
@@ -27,20 +36,69 @@ impl WholeStreamCommand for SubCommand {
         args: CommandArgs,
         registry: &CommandRegistry,
     ) -> Result<OutputStream, ShellError> {
-        run_with_function(
-            RunnableContext {
-                input: args.input,
-                registry: registry.clone(),
-                shell_manager: args.shell_manager,
-                host: args.host,
-                ctrl_c: args.ctrl_c,
-                current_errors: args.current_errors,
-                name: args.call_info.name_tag,
-                raw_input: args.raw_input,
-            },
-            stddev,
-        )
-        .await
+        let name = args.call_info.name_tag.clone();
+        let (Arguments { sample }, mut input) = args.process(&registry).await?;
+
+        let values: Vec<Value> = input.drain_vec().await;
+
+        let n = if let Tagged { item: true, .. } = sample {
+            values.len() - 1
+        } else {
+            values.len()
+        };
+
+        let res = if values.iter().all(|v| v.is_primitive()) {
+            compute_stddev(&values, n, &name)
+        } else {
+            // If we are not dealing with Primitives, then perhaps we are dealing with a table
+            // Create a key for each column name
+            let mut column_values = IndexMap::new();
+            for value in values {
+                if let UntaggedValue::Row(row_dict) = &value.value {
+                    for (key, value) in row_dict.entries.iter() {
+                        column_values
+                            .entry(key.clone())
+                            .and_modify(|v: &mut Vec<Value>| v.push(value.clone()))
+                            .or_insert(vec![value.clone()]);
+                    }
+                }
+            }
+            // The mathematical function operates over the columns of the table
+            let mut column_totals = IndexMap::new();
+            for (col_name, col_vals) in column_values {
+                if let Ok(out) = compute_stddev(&col_vals, n, &name) {
+                    column_totals.insert(col_name, out);
+                }
+            }
+
+            if column_totals.keys().len() == 0 {
+                return Err(ShellError::labeled_error(
+                    "Attempted to compute values that can't be operated on",
+                    "value appears here",
+                    name.span,
+                ));
+            }
+
+            Ok(UntaggedValue::Row(Dictionary {
+                entries: column_totals,
+            })
+            .into_untagged_value())
+        };
+
+        match res {
+            Ok(v) => {
+                if v.value.is_table() {
+                    Ok(OutputStream::from(
+                        v.table_entries()
+                            .map(|v| ReturnSuccess::value(v.clone()))
+                            .collect::<Vec<_>>(),
+                    ))
+                } else {
+                    Ok(OutputStream::one(ReturnSuccess::value(v)))
+                }
+            }
+            Err(e) => Err(e),
+        }
     }
 
     fn examples(&self) -> Vec<Example> {
@@ -52,8 +110,13 @@ impl WholeStreamCommand for SubCommand {
     }
 }
 
+#[cfg(test)]
 pub fn stddev(values: &[Value], name: &Tag) -> Result<Value, ShellError> {
-    let variance = variance(values, name)?.as_primitive()?;
+    compute_stddev(values, values.len(), name)
+}
+
+pub fn compute_stddev(values: &[Value], n: usize, name: &Tag) -> Result<Value, ShellError> {
+    let variance = variance(values, n, name)?.as_primitive()?;
     let sqrt_var = match variance {
         Primitive::Decimal(var) => var.sqrt(),
         _ => {
diff --git a/crates/nu-cli/src/commands/math/variance.rs b/crates/nu-cli/src/commands/math/variance.rs
index 82a9205288..59dda37ec6 100644
--- a/crates/nu-cli/src/commands/math/variance.rs
+++ b/crates/nu-cli/src/commands/math/variance.rs
@@ -1,13 +1,20 @@
-use crate::commands::math::utils::run_with_function;
 use crate::commands::WholeStreamCommand;
 use crate::data::value::compute_values;
 use crate::prelude::*;
 use bigdecimal::FromPrimitive;
 use nu_errors::ShellError;
-use nu_protocol::{hir::Operator, Primitive, Signature, UntaggedValue, Value};
+use nu_protocol::{
+    hir::Operator, Dictionary, Primitive, ReturnSuccess, Signature, UntaggedValue, Value,
+};
+use nu_source::Tagged;
 
 pub struct SubCommand;
 
+#[derive(Deserialize)]
+struct Arguments {
+    sample: Tagged<bool>,
+}
+
 #[async_trait]
 impl WholeStreamCommand for SubCommand {
     fn name(&self) -> &str {
@@ -15,7 +22,7 @@ impl WholeStreamCommand for SubCommand {
     }
 
     fn signature(&self) -> Signature {
-        Signature::build("math variance")
+        Signature::build("math variance").switch("sample", "calculate sample variance", Some('s'))
     }
 
     fn usage(&self) -> &str {
@@ -27,20 +34,69 @@ impl WholeStreamCommand for SubCommand {
         args: CommandArgs,
         registry: &CommandRegistry,
     ) -> Result<OutputStream, ShellError> {
-        run_with_function(
-            RunnableContext {
-                input: args.input,
-                registry: registry.clone(),
-                shell_manager: args.shell_manager,
-                host: args.host,
-                ctrl_c: args.ctrl_c,
-                current_errors: args.current_errors,
-                name: args.call_info.name_tag,
-                raw_input: args.raw_input,
-            },
-            variance,
-        )
-        .await
+        let name = args.call_info.name_tag.clone();
+        let (Arguments { sample }, mut input) = args.process(&registry).await?;
+
+        let values: Vec<Value> = input.drain_vec().await;
+
+        let n = if let Tagged { item: true, .. } = sample {
+            values.len() - 1
+        } else {
+            values.len()
+        };
+
+        let res = if values.iter().all(|v| v.is_primitive()) {
+            compute_variance(&values, n, &name)
+        } else {
+            // If we are not dealing with Primitives, then perhaps we are dealing with a table
+            // Create a key for each column name
+            let mut column_values = IndexMap::new();
+            for value in values {
+                if let UntaggedValue::Row(row_dict) = &value.value {
+                    for (key, value) in row_dict.entries.iter() {
+                        column_values
+                            .entry(key.clone())
+                            .and_modify(|v: &mut Vec<Value>| v.push(value.clone()))
+                            .or_insert(vec![value.clone()]);
+                    }
+                }
+            }
+            // The mathematical function operates over the columns of the table
+            let mut column_totals = IndexMap::new();
+            for (col_name, col_vals) in column_values {
+                if let Ok(out) = compute_variance(&col_vals, n, &name) {
+                    column_totals.insert(col_name, out);
+                }
+            }
+
+            if column_totals.keys().len() == 0 {
+                return Err(ShellError::labeled_error(
+                    "Attempted to compute values that can't be operated on",
+                    "value appears here",
+                    name.span,
+                ));
+            }
+
+            Ok(UntaggedValue::Row(Dictionary {
+                entries: column_totals,
+            })
+            .into_untagged_value())
+        };
+
+        match res {
+            Ok(v) => {
+                if v.value.is_table() {
+                    Ok(OutputStream::from(
+                        v.table_entries()
+                            .map(|v| ReturnSuccess::value(v.clone()))
+                            .collect::<Vec<_>>(),
+                    ))
+                } else {
+                    Ok(OutputStream::one(ReturnSuccess::value(v)))
+                }
+            }
+            Err(e) => Err(e),
+        }
     }
 
     fn examples(&self) -> Vec<Example> {
@@ -147,9 +203,14 @@ fn sum_of_squares(values: &[Value], name: &Tag) -> Result<Value, ShellError> {
     Ok(ss)
 }
 
+#[cfg(test)]
 pub fn variance(values: &[Value], name: &Tag) -> Result<Value, ShellError> {
+    compute_variance(values, values.len(), name)
+}
+
+pub fn compute_variance(values: &[Value], n: usize, name: &Tag) -> Result<Value, ShellError> {
     let ss = sum_of_squares(values, name)?;
-    let n = BigDecimal::from_usize(values.len()).ok_or_else(|| {
+    let n = BigDecimal::from_usize(n).ok_or_else(|| {
         ShellError::labeled_error(
             "could not convert to big decimal",
             "could not convert to big decimal",