diff --git a/Cargo.lock b/Cargo.lock index ea9de5c2a2..1a1f4ca52c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -207,7 +207,7 @@ checksum = "23b62fc65de8e4e7f52534fb52b0f3ed04746ae267519eef2a83941e8085068b" [[package]] name = "arrow" version = "5.0.0-SNAPSHOT" -source = "git+https://github.com/apache/arrow-rs?rev=f26ffb3091ae355d246edc4a6fcc2c8e5b9bc570#f26ffb3091ae355d246edc4a6fcc2c8e5b9bc570" +source = "git+https://github.com/apache/arrow-rs?rev=0f55b828883b3b3afda43ae404b130d374e6f1a1#0f55b828883b3b3afda43ae404b130d374e6f1a1" dependencies = [ "chrono", "csv", @@ -3504,6 +3504,7 @@ dependencies = [ "num-bigint 0.3.2", "num-format", "num-traits 0.2.14", + "polars", "query_interface", "serde 1.0.126", "sha2 0.9.5", @@ -4362,7 +4363,7 @@ dependencies = [ [[package]] name = "parquet" version = "5.0.0-SNAPSHOT" -source = "git+https://github.com/apache/arrow-rs?rev=f26ffb3091ae355d246edc4a6fcc2c8e5b9bc570#f26ffb3091ae355d246edc4a6fcc2c8e5b9bc570" +source = "git+https://github.com/apache/arrow-rs?rev=0f55b828883b3b3afda43ae404b130d374e6f1a1#0f55b828883b3b3afda43ae404b130d374e6f1a1" dependencies = [ "arrow", "base64 0.13.0", @@ -4600,8 +4601,8 @@ dependencies = [ [[package]] name = "polars" -version = "0.14.0" -source = "git+https://github.com/pola-rs/polars?rev=a5f17b0a6e3e05ff6be789aa24a7cae54fd400dd#a5f17b0a6e3e05ff6be789aa24a7cae54fd400dd" +version = "0.14.1" +source = "git+https://github.com/pola-rs/polars?rev=9e1506cca9fb646fc55f949ab6345290c3d198a7#9e1506cca9fb646fc55f949ab6345290c3d198a7" dependencies = [ "polars-core", "polars-io", @@ -4610,8 +4611,8 @@ dependencies = [ [[package]] name = "polars-arrow" -version = "0.14.0" -source = "git+https://github.com/pola-rs/polars?rev=a5f17b0a6e3e05ff6be789aa24a7cae54fd400dd#a5f17b0a6e3e05ff6be789aa24a7cae54fd400dd" +version = "0.14.1" +source = "git+https://github.com/pola-rs/polars?rev=9e1506cca9fb646fc55f949ab6345290c3d198a7#9e1506cca9fb646fc55f949ab6345290c3d198a7" dependencies = [ "arrow", "num 0.4.0", @@ -4620,8 +4621,8 @@ dependencies = [ [[package]] name = "polars-core" -version = "0.14.0" -source = "git+https://github.com/pola-rs/polars?rev=a5f17b0a6e3e05ff6be789aa24a7cae54fd400dd#a5f17b0a6e3e05ff6be789aa24a7cae54fd400dd" +version = "0.14.1" +source = "git+https://github.com/pola-rs/polars?rev=9e1506cca9fb646fc55f949ab6345290c3d198a7#9e1506cca9fb646fc55f949ab6345290c3d198a7" dependencies = [ "ahash", "anyhow", @@ -4646,8 +4647,8 @@ dependencies = [ [[package]] name = "polars-io" -version = "0.14.0" -source = "git+https://github.com/pola-rs/polars?rev=a5f17b0a6e3e05ff6be789aa24a7cae54fd400dd#a5f17b0a6e3e05ff6be789aa24a7cae54fd400dd" +version = "0.14.1" +source = "git+https://github.com/pola-rs/polars?rev=9e1506cca9fb646fc55f949ab6345290c3d198a7#9e1506cca9fb646fc55f949ab6345290c3d198a7" dependencies = [ "ahash", "anyhow", @@ -4669,8 +4670,8 @@ dependencies = [ [[package]] name = "polars-lazy" -version = "0.14.0" -source = "git+https://github.com/pola-rs/polars?rev=a5f17b0a6e3e05ff6be789aa24a7cae54fd400dd#a5f17b0a6e3e05ff6be789aa24a7cae54fd400dd" +version = "0.14.1" +source = "git+https://github.com/pola-rs/polars?rev=9e1506cca9fb646fc55f949ab6345290c3d198a7#9e1506cca9fb646fc55f949ab6345290c3d198a7" dependencies = [ "ahash", "itertools", diff --git a/Cargo.toml b/Cargo.toml index 367e830d0d..71d29530b1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -156,6 +156,7 @@ table-pager = ["nu-command/table-pager"] #dataframe feature for nushell dataframe = [ + "nu-engine/dataframe", "nu-protocol/dataframe", "nu-command/dataframe", "nu-value-ext/dataframe", diff --git a/crates/nu-command/Cargo.toml b/crates/nu-command/Cargo.toml index af42ee4b65..c032822f5c 100644 --- a/crates/nu-command/Cargo.toml +++ b/crates/nu-command/Cargo.toml @@ -101,10 +101,10 @@ zip = { version = "0.5.9", optional = true } [dependencies.polars] git = "https://github.com/pola-rs/polars" -rev = "a5f17b0a6e3e05ff6be789aa24a7cae54fd400dd" -version = "0.14.0" +rev = "9e1506cca9fb646fc55f949ab6345290c3d198a7" +version = "0.14.1" optional = true -features = ["parquet", "json", "random", "pivot"] +features = ["parquet", "json", "random", "pivot", "strings"] [target.'cfg(unix)'.dependencies] umask = "1.0.0" diff --git a/crates/nu-command/src/commands/dataframe/groupby.rs b/crates/nu-command/src/commands/dataframe/groupby.rs index 80691740b1..80c2072059 100644 --- a/crates/nu-command/src/commands/dataframe/groupby.rs +++ b/crates/nu-command/src/commands/dataframe/groupby.rs @@ -12,7 +12,7 @@ pub struct DataFrame; impl WholeStreamCommand for DataFrame { fn name(&self) -> &str { - "pls groupby" + "pls group-by" } fn usage(&self) -> &str { @@ -20,7 +20,7 @@ impl WholeStreamCommand for DataFrame { } fn signature(&self) -> Signature { - Signature::build("pls groupby").required( + Signature::build("pls group-by").required( "by columns", SyntaxShape::Table, "groupby columns", @@ -34,7 +34,7 @@ impl WholeStreamCommand for DataFrame { fn examples(&self) -> Vec { vec![Example { description: "Grouping by column a", - example: "[[a b]; [one 1] [one 2]] | pls to-df | pls groupby [a]", + example: "[[a b]; [one 1] [one 2]] | pls to-df | pls group-by [a]", result: None, }] } diff --git a/crates/nu-command/src/commands/dataframe/where_.rs b/crates/nu-command/src/commands/dataframe/where_.rs index 4ac7a33fc3..cc5094f064 100644 --- a/crates/nu-command/src/commands/dataframe/where_.rs +++ b/crates/nu-command/src/commands/dataframe/where_.rs @@ -4,11 +4,11 @@ use nu_errors::ShellError; use nu_protocol::{ dataframe::NuDataFrame, hir::{CapturedBlock, ClassifiedCommand, Expression, Literal, Operator, SpannedExpression}, - Primitive, Signature, SyntaxShape, UnspannedPathMember, UntaggedValue, + Primitive, Signature, SyntaxShape, UnspannedPathMember, UntaggedValue, Value, }; use super::utils::parse_polars_error; -use polars::prelude::{ChunkCompare, Series}; +use polars::prelude::{ChunkCompare, DataType, Series}; pub struct DataFrame; @@ -91,22 +91,8 @@ fn command(args: CommandArgs) -> Result { }?; let rhs = evaluate_baseline_expr(&expression.right, &args.args.context)?; - let right_condition = match &rhs.value { - UntaggedValue::Primitive(primitive) => Ok(primitive), - _ => Err(ShellError::labeled_error( - "Incorrect argument", - "Expected primitive values", - &rhs.tag.span, - )), - }?; - filter_dataframe( - args, - &col_name, - &col_name_span, - &right_condition, - &expression.op, - ) + filter_dataframe(args, &col_name, &col_name_span, &rhs, &expression.op) } macro_rules! comparison_arm { @@ -145,16 +131,25 @@ fn filter_dataframe( mut args: EvaluatedCommandArgs, col_name: &str, col_name_span: &Span, - right_condition: &Primitive, + rhs: &Value, operator: &SpannedExpression, ) -> Result { + let right_condition = match &rhs.value { + UntaggedValue::Primitive(primitive) => Ok(primitive), + _ => Err(ShellError::labeled_error( + "Incorrect argument", + "Expected primitive values", + &rhs.tag.span, + )), + }?; + let span = args.call_info.name_tag.span; let df = NuDataFrame::try_from_stream(&mut args.input, &span)?; let col = df .as_ref() .column(col_name) - .map_err(|e| parse_polars_error::<&str>(&e, &col_name_span, None))?; + .map_err(|e| parse_polars_error::<&str>(&e, col_name_span, None))?; let op = match &operator.expr { Expression::Literal(Literal::Operator(op)) => Ok(op), @@ -176,6 +171,33 @@ fn filter_dataframe( Operator::GreaterThanOrEqual => { comparison_arm!(Series::gt_eq, col, right_condition, operator.span) } + Operator::Contains => match col.dtype() { + DataType::Utf8 => match right_condition { + Primitive::String(pat) => { + let casted = col.utf8().map_err(|e| { + parse_polars_error::<&str>(&e, &args.call_info.name_tag.span, None) + })?; + + casted.contains(pat).map_err(|e| { + parse_polars_error::<&str>(&e, &args.call_info.name_tag.span, None) + }) + } + _ => Err(ShellError::labeled_error_with_secondary( + "Incorrect argument", + "Can't perform contains with this value", + &rhs.tag.span, + "Contains only works with strings", + &rhs.tag.span, + )), + }, + _ => Err(ShellError::labeled_error_with_secondary( + "Incorrect datatype", + format!("The selected column is of type '{}'", col.dtype()), + col_name_span, + "Perhaps you want to select a column of 'str' type", + col_name_span, + )), + }, _ => Err(ShellError::labeled_error( "Incorrect operator", "Not implemented operator for dataframes filter", diff --git a/crates/nu-data/Cargo.toml b/crates/nu-data/Cargo.toml index bdeabf7d78..181ad40fcc 100644 --- a/crates/nu-data/Cargo.toml +++ b/crates/nu-data/Cargo.toml @@ -37,10 +37,17 @@ nu-test-support = { version = "0.32.1", path = "../nu-test-support" } nu-value-ext = { version = "0.32.1", path = "../nu-value-ext" } nu-ansi-term = { version = "0.32.1", path = "../nu-ansi-term" } +[dependencies.polars] +git = "https://github.com/pola-rs/polars" +rev = "9e1506cca9fb646fc55f949ab6345290c3d198a7" +version = "0.14.1" +optional = true +features = ["strings", "checked_arithmetic"] + [target.'cfg(unix)'.dependencies] users = "0.11.0" [features] directories = ["directories-next"] dirs = ["dirs-next"] -dataframe = ["nu-protocol/dataframe"] +dataframe = ["nu-protocol/dataframe", "polars"] diff --git a/crates/nu-data/src/dataframe.rs b/crates/nu-data/src/dataframe.rs new file mode 100644 index 0000000000..e778c83fb5 --- /dev/null +++ b/crates/nu-data/src/dataframe.rs @@ -0,0 +1,717 @@ +use bigdecimal::BigDecimal; +use nu_errors::ShellError; +use nu_protocol::hir::Operator; +use nu_protocol::{ + dataframe::{NuSeries, PolarsData}, + Primitive, ShellTypeName, UntaggedValue, Value, +}; +use nu_source::Span; +use num_traits::ToPrimitive; + +use num_bigint::BigInt; +use polars::prelude::{ + BooleanType, ChunkCompare, ChunkedArray, DataType, Float64Type, Int64Type, IntoSeries, + NumOpsDispatchChecked, Series, +}; +use std::ops::{Add, BitAnd, BitOr, Div, Mul, Sub}; + +pub fn compute_between_series( + operator: Operator, + left: &Value, + right: &Value, +) -> Result { + if let ( + UntaggedValue::DataFrame(PolarsData::Series(lhs)), + UntaggedValue::DataFrame(PolarsData::Series(rhs)), + ) = (&left.value, &right.value) + { + if lhs.as_ref().dtype() != rhs.as_ref().dtype() { + return Ok(UntaggedValue::Error( + ShellError::labeled_error_with_secondary( + "Mixed datatypes", + "this datatype does not match the right hand side datatype", + &left.tag.span, + format!( + "Perhaps you want to change this datatype to '{}'", + lhs.as_ref().dtype() + ), + &right.tag.span, + ), + )); + } + + if lhs.as_ref().len() != rhs.as_ref().len() { + return Ok(UntaggedValue::Error(ShellError::labeled_error( + "Different length", + "this column length does not match the right hand column length", + &left.tag.span, + ))); + } + + match operator { + Operator::Plus => { + let mut res = lhs.as_ref() + rhs.as_ref(); + let name = format!("sum_{}_{}", lhs.as_ref().name(), rhs.as_ref().name()); + res.rename(name.as_ref()); + Ok(NuSeries::series_to_untagged(res)) + } + Operator::Minus => { + let mut res = lhs.as_ref() - rhs.as_ref(); + let name = format!("sub_{}_{}", lhs.as_ref().name(), rhs.as_ref().name()); + res.rename(name.as_ref()); + Ok(NuSeries::series_to_untagged(res)) + } + Operator::Multiply => { + let mut res = lhs.as_ref() * rhs.as_ref(); + let name = format!("mul_{}_{}", lhs.as_ref().name(), rhs.as_ref().name()); + res.rename(name.as_ref()); + Ok(NuSeries::series_to_untagged(res)) + } + Operator::Divide => { + let res = lhs.as_ref().checked_div(rhs.as_ref()); + match res { + Ok(mut res) => { + let name = format!("div_{}_{}", lhs.as_ref().name(), rhs.as_ref().name()); + res.rename(name.as_ref()); + Ok(NuSeries::series_to_untagged(res)) + } + Err(e) => Ok(UntaggedValue::Error(ShellError::labeled_error( + "Division error", + format!("{}", e), + &left.tag.span, + ))), + } + } + Operator::Equal => { + let mut res = Series::eq(lhs.as_ref(), rhs.as_ref()).into_series(); + let name = format!("eq_{}_{}", lhs.as_ref().name(), rhs.as_ref().name()); + res.rename(name.as_ref()); + Ok(NuSeries::series_to_untagged(res)) + } + Operator::NotEqual => { + let mut res = Series::neq(lhs.as_ref(), rhs.as_ref()).into_series(); + let name = format!("neq_{}_{}", lhs.as_ref().name(), rhs.as_ref().name()); + res.rename(name.as_ref()); + Ok(NuSeries::series_to_untagged(res)) + } + Operator::LessThan => { + let mut res = Series::lt(lhs.as_ref(), rhs.as_ref()).into_series(); + let name = format!("lt_{}_{}", lhs.as_ref().name(), rhs.as_ref().name()); + res.rename(name.as_ref()); + Ok(NuSeries::series_to_untagged(res)) + } + Operator::LessThanOrEqual => { + let mut res = Series::lt_eq(lhs.as_ref(), rhs.as_ref()).into_series(); + let name = format!("lte_{}_{}", lhs.as_ref().name(), rhs.as_ref().name()); + res.rename(name.as_ref()); + Ok(NuSeries::series_to_untagged(res)) + } + Operator::GreaterThan => { + let mut res = Series::gt(lhs.as_ref(), rhs.as_ref()).into_series(); + let name = format!("gt_{}_{}", lhs.as_ref().name(), rhs.as_ref().name()); + res.rename(name.as_ref()); + Ok(NuSeries::series_to_untagged(res)) + } + Operator::GreaterThanOrEqual => { + let mut res = Series::gt_eq(lhs.as_ref(), rhs.as_ref()).into_series(); + let name = format!("gte_{}_{}", lhs.as_ref().name(), rhs.as_ref().name()); + res.rename(name.as_ref()); + Ok(NuSeries::series_to_untagged(res)) + } + Operator::And => match lhs.as_ref().dtype() { + DataType::Boolean => { + let lhs_cast = lhs.as_ref().bool(); + let rhs_cast = rhs.as_ref().bool(); + + match (lhs_cast, rhs_cast) { + (Ok(l), Ok(r)) => { + let mut res = l.bitand(r).into_series(); + let name = + format!("and_{}_{}", lhs.as_ref().name(), rhs.as_ref().name()); + res.rename(name.as_ref()); + Ok(NuSeries::series_to_untagged(res)) + } + _ => Ok(UntaggedValue::Error( + ShellError::labeled_error_with_secondary( + "Casting error", + "unable to cast to boolean", + &left.tag.span, + "unable to cast to boolean", + &right.tag.span, + ), + )), + } + } + _ => Ok(UntaggedValue::Error(ShellError::labeled_error( + "Incorrect datatype", + "And operation can only be done with boolean values", + &left.tag.span, + ))), + }, + Operator::Or => match lhs.as_ref().dtype() { + DataType::Boolean => { + let lhs_cast = lhs.as_ref().bool(); + let rhs_cast = rhs.as_ref().bool(); + + match (lhs_cast, rhs_cast) { + (Ok(l), Ok(r)) => { + let mut res = l.bitor(r).into_series(); + let name = + format!("or_{}_{}", lhs.as_ref().name(), rhs.as_ref().name()); + res.rename(name.as_ref()); + Ok(NuSeries::series_to_untagged(res)) + } + _ => Ok(UntaggedValue::Error( + ShellError::labeled_error_with_secondary( + "Casting error", + "unable to cast to boolean", + &left.tag.span, + "unable to cast to boolean", + &right.tag.span, + ), + )), + } + } + _ => Ok(UntaggedValue::Error(ShellError::labeled_error( + "Incorrect datatype", + "And operation can only be done with boolean values", + &left.tag.span, + ))), + }, + _ => Ok(UntaggedValue::Error(ShellError::labeled_error( + "Incorrect datatype", + "unable to use this datatype for this operation", + &left.tag.span, + ))), + } + } else { + Err((left.type_name(), right.type_name())) + } +} + +pub fn compute_series_single_value( + operator: Operator, + left: &Value, + right: &Value, +) -> Result { + if let (UntaggedValue::DataFrame(PolarsData::Series(lhs)), UntaggedValue::Primitive(_)) = + (&left.value, &right.value) + { + match operator { + Operator::Plus => match &right.value { + UntaggedValue::Primitive(Primitive::Int(val)) => Ok(compute_series_i64( + lhs.as_ref(), + val, + <&ChunkedArray>::add, + &left.tag.span, + )), + UntaggedValue::Primitive(Primitive::BigInt(val)) => Ok(compute_series_bigint( + lhs.as_ref(), + val, + <&ChunkedArray>::add, + &left.tag.span, + )), + UntaggedValue::Primitive(Primitive::Decimal(val)) => Ok(compute_series_decimal( + lhs.as_ref(), + val, + <&ChunkedArray>::add, + &left.tag.span, + )), + _ => Ok(UntaggedValue::Error( + ShellError::labeled_error_with_secondary( + "Operation unavailable", + "unable to sum this value to the series", + &right.tag.span, + "Only int, bigInt or decimal values are allowed", + &right.tag.span, + ), + )), + }, + Operator::Minus => match &right.value { + UntaggedValue::Primitive(Primitive::Int(val)) => Ok(compute_series_i64( + lhs.as_ref(), + val, + <&ChunkedArray>::sub, + &left.tag.span, + )), + UntaggedValue::Primitive(Primitive::BigInt(val)) => Ok(compute_series_bigint( + lhs.as_ref(), + val, + <&ChunkedArray>::sub, + &left.tag.span, + )), + UntaggedValue::Primitive(Primitive::Decimal(val)) => Ok(compute_series_decimal( + lhs.as_ref(), + val, + <&ChunkedArray>::sub, + &left.tag.span, + )), + _ => Ok(UntaggedValue::Error( + ShellError::labeled_error_with_secondary( + "Operation unavailable", + "unable to subtract this value to the series", + &right.tag.span, + "Only int, bigInt or decimal values are allowed", + &right.tag.span, + ), + )), + }, + Operator::Multiply => match &right.value { + UntaggedValue::Primitive(Primitive::Int(val)) => Ok(compute_series_i64( + lhs.as_ref(), + val, + <&ChunkedArray>::mul, + &left.tag.span, + )), + UntaggedValue::Primitive(Primitive::BigInt(val)) => Ok(compute_series_bigint( + lhs.as_ref(), + val, + <&ChunkedArray>::mul, + &left.tag.span, + )), + UntaggedValue::Primitive(Primitive::Decimal(val)) => Ok(compute_series_decimal( + lhs.as_ref(), + val, + <&ChunkedArray>::mul, + &left.tag.span, + )), + _ => Ok(UntaggedValue::Error( + ShellError::labeled_error_with_secondary( + "Operation unavailable", + "unable to multiply this value to the series", + &right.tag.span, + "Only int, bigInt or decimal values are allowed", + &right.tag.span, + ), + )), + }, + Operator::Divide => match &right.value { + UntaggedValue::Primitive(Primitive::Int(val)) => { + if *val == 0 { + Ok(UntaggedValue::Error(ShellError::labeled_error( + "Division by zero", + "Zero value found", + &right.tag.span, + ))) + } else { + Ok(compute_series_i64( + lhs.as_ref(), + val, + <&ChunkedArray>::div, + &left.tag.span, + )) + } + } + UntaggedValue::Primitive(Primitive::BigInt(val)) => { + if val.eq(&0.into()) { + Ok(UntaggedValue::Error(ShellError::labeled_error( + "Division by zero", + "Zero value found", + &right.tag.span, + ))) + } else { + Ok(compute_series_bigint( + lhs.as_ref(), + val, + <&ChunkedArray>::div, + &left.tag.span, + )) + } + } + UntaggedValue::Primitive(Primitive::Decimal(val)) => { + if val.eq(&0.into()) { + Ok(UntaggedValue::Error(ShellError::labeled_error( + "Division by zero", + "Zero value found", + &right.tag.span, + ))) + } else { + Ok(compute_series_decimal( + lhs.as_ref(), + val, + <&ChunkedArray>::div, + &left.tag.span, + )) + } + } + _ => Ok(UntaggedValue::Error( + ShellError::labeled_error_with_secondary( + "Operation unavailable", + "unable to divide this value to the series", + &right.tag.span, + "Only primary values are allowed", + &right.tag.span, + ), + )), + }, + Operator::Equal => { + match &right.value { + UntaggedValue::Primitive(Primitive::Int(val)) => Ok(compare_series_i64( + lhs.as_ref(), + val, + ChunkedArray::eq, + &left.tag.span, + )), + UntaggedValue::Primitive(Primitive::BigInt(val)) => Ok(compare_series_bigint( + lhs.as_ref(), + val, + ChunkedArray::eq, + &left.tag.span, + )), + UntaggedValue::Primitive(Primitive::Decimal(val)) => Ok( + compare_series_decimal(lhs.as_ref(), val, ChunkedArray::eq, &left.tag.span), + ), + _ => Ok(UntaggedValue::Error( + ShellError::labeled_error_with_secondary( + "Operation unavailable", + "unable to compare this value to the series", + &right.tag.span, + "Only primary values are allowed", + &right.tag.span, + ), + )), + } + } + Operator::NotEqual => match &right.value { + UntaggedValue::Primitive(Primitive::Int(val)) => Ok(compare_series_i64( + lhs.as_ref(), + val, + ChunkedArray::neq, + &left.tag.span, + )), + UntaggedValue::Primitive(Primitive::BigInt(val)) => Ok(compare_series_bigint( + lhs.as_ref(), + val, + ChunkedArray::neq, + &left.tag.span, + )), + UntaggedValue::Primitive(Primitive::Decimal(val)) => Ok(compare_series_decimal( + lhs.as_ref(), + val, + ChunkedArray::neq, + &left.tag.span, + )), + _ => Ok(UntaggedValue::Error( + ShellError::labeled_error_with_secondary( + "Operation unavailable", + "unable to compare this value to the series", + &right.tag.span, + "Only primary values are allowed", + &right.tag.span, + ), + )), + }, + Operator::LessThan => { + match &right.value { + UntaggedValue::Primitive(Primitive::Int(val)) => Ok(compare_series_i64( + lhs.as_ref(), + val, + ChunkedArray::lt, + &left.tag.span, + )), + UntaggedValue::Primitive(Primitive::BigInt(val)) => Ok(compare_series_bigint( + lhs.as_ref(), + val, + ChunkedArray::lt, + &left.tag.span, + )), + UntaggedValue::Primitive(Primitive::Decimal(val)) => Ok( + compare_series_decimal(lhs.as_ref(), val, ChunkedArray::lt, &left.tag.span), + ), + _ => Ok(UntaggedValue::Error( + ShellError::labeled_error_with_secondary( + "Operation unavailable", + "unable to compare this value to the series", + &right.tag.span, + "Only primary values are allowed", + &right.tag.span, + ), + )), + } + } + Operator::LessThanOrEqual => match &right.value { + UntaggedValue::Primitive(Primitive::Int(val)) => Ok(compare_series_i64( + lhs.as_ref(), + val, + ChunkedArray::lt_eq, + &left.tag.span, + )), + UntaggedValue::Primitive(Primitive::BigInt(val)) => Ok(compare_series_bigint( + lhs.as_ref(), + val, + ChunkedArray::lt_eq, + &left.tag.span, + )), + UntaggedValue::Primitive(Primitive::Decimal(val)) => Ok(compare_series_decimal( + lhs.as_ref(), + val, + ChunkedArray::lt_eq, + &left.tag.span, + )), + _ => Ok(UntaggedValue::Error( + ShellError::labeled_error_with_secondary( + "Operation unavailable", + "unable to compare this value to the series", + &right.tag.span, + "Only primary values are allowed", + &right.tag.span, + ), + )), + }, + Operator::GreaterThan => { + match &right.value { + UntaggedValue::Primitive(Primitive::Int(val)) => Ok(compare_series_i64( + lhs.as_ref(), + val, + ChunkedArray::gt, + &left.tag.span, + )), + UntaggedValue::Primitive(Primitive::BigInt(val)) => Ok(compare_series_bigint( + lhs.as_ref(), + val, + ChunkedArray::gt, + &left.tag.span, + )), + UntaggedValue::Primitive(Primitive::Decimal(val)) => Ok( + compare_series_decimal(lhs.as_ref(), val, ChunkedArray::gt, &left.tag.span), + ), + _ => Ok(UntaggedValue::Error( + ShellError::labeled_error_with_secondary( + "Operation unavailable", + "unable to compare this value to the series", + &right.tag.span, + "Only primary values are allowed", + &right.tag.span, + ), + )), + } + } + Operator::GreaterThanOrEqual => match &right.value { + UntaggedValue::Primitive(Primitive::Int(val)) => Ok(compare_series_i64( + lhs.as_ref(), + val, + ChunkedArray::gt_eq, + &left.tag.span, + )), + UntaggedValue::Primitive(Primitive::BigInt(val)) => Ok(compare_series_bigint( + lhs.as_ref(), + val, + ChunkedArray::gt_eq, + &left.tag.span, + )), + UntaggedValue::Primitive(Primitive::Decimal(val)) => Ok(compare_series_decimal( + lhs.as_ref(), + val, + ChunkedArray::gt_eq, + &left.tag.span, + )), + _ => Ok(UntaggedValue::Error( + ShellError::labeled_error_with_secondary( + "Operation unavailable", + "unable to compare this value to the series", + &right.tag.span, + "Only primary values are allowed", + &right.tag.span, + ), + )), + }, + Operator::Contains => match &right.value { + UntaggedValue::Primitive(Primitive::String(val)) => { + Ok(contains_series_pat(lhs.as_ref(), val, &left.tag.span)) + } + _ => Ok(UntaggedValue::Error( + ShellError::labeled_error_with_secondary( + "Operation unavailable", + "unable to perform this value to the series", + &right.tag.span, + "Only primary values are allowed", + &right.tag.span, + ), + )), + }, + _ => Ok(UntaggedValue::Error(ShellError::labeled_error( + "Incorrect datatype", + "unable to use this value for this operation", + &left.tag.span, + ))), + } + } else { + Err((left.type_name(), right.type_name())) + } +} + +fn compute_series_i64<'r, F>(series: &'r Series, val: &i64, f: F, span: &Span) -> UntaggedValue +where + F: Fn(&'r ChunkedArray, i64) -> ChunkedArray, +{ + let casted = series.i64(); + match casted { + Ok(casted) => { + let res = f(casted, *val); + let res = res.into_series(); + NuSeries::series_to_untagged(res) + } + Err(e) => UntaggedValue::Error(ShellError::labeled_error( + "Casting error", + format!("{}", e), + span, + )), + } +} + +fn compute_series_bigint<'r, F>( + series: &'r Series, + val: &BigInt, + f: F, + span: &Span, +) -> UntaggedValue +where + F: Fn(&'r ChunkedArray, i64) -> ChunkedArray, +{ + let casted = series.i64(); + match casted { + Ok(casted) => { + let res = f( + casted, + val.to_i64() + .expect("Internal error: protocol did not use compatible decimal"), + ); + let res = res.into_series(); + NuSeries::series_to_untagged(res) + } + Err(e) => UntaggedValue::Error(ShellError::labeled_error( + "Casting error", + format!("{}", e), + span, + )), + } +} + +fn compute_series_decimal<'r, F>( + series: &'r Series, + val: &BigDecimal, + f: F, + span: &Span, +) -> UntaggedValue +where + F: Fn(&'r ChunkedArray, f64) -> ChunkedArray, +{ + let casted = series.f64(); + match casted { + Ok(casted) => { + let res = f( + casted, + val.to_f64() + .expect("Internal error: protocol did not use compatible decimal"), + ); + let res = res.into_series(); + NuSeries::series_to_untagged(res) + } + Err(e) => UntaggedValue::Error(ShellError::labeled_error( + "Casting error", + format!("{}", e), + span, + )), + } +} + +fn compare_series_i64<'r, F>(series: &'r Series, val: &i64, f: F, span: &Span) -> UntaggedValue +where + F: Fn(&'r ChunkedArray, i64) -> ChunkedArray, +{ + let casted = series.i64(); + match casted { + Ok(casted) => { + let res = f(casted, *val); + let res = res.into_series(); + NuSeries::series_to_untagged(res) + } + Err(e) => UntaggedValue::Error(ShellError::labeled_error( + "Casting error", + format!("{}", e), + span, + )), + } +} + +fn compare_series_bigint<'r, F>( + series: &'r Series, + val: &BigInt, + f: F, + span: &Span, +) -> UntaggedValue +where + F: Fn(&'r ChunkedArray, i64) -> ChunkedArray, +{ + let casted = series.i64(); + match casted { + Ok(casted) => { + let res = f( + casted, + val.to_i64() + .expect("Internal error: protocol did not use compatible decimal"), + ); + let res = res.into_series(); + NuSeries::series_to_untagged(res) + } + Err(e) => UntaggedValue::Error(ShellError::labeled_error( + "Casting error", + format!("{}", e), + span, + )), + } +} + +fn compare_series_decimal<'r, F>( + series: &'r Series, + val: &BigDecimal, + f: F, + span: &Span, +) -> UntaggedValue +where + F: Fn(&'r ChunkedArray, i64) -> ChunkedArray, +{ + let casted = series.f64(); + match casted { + Ok(casted) => { + let res = f( + casted, + val.to_i64() + .expect("Internal error: protocol did not use compatible decimal"), + ); + let res = res.into_series(); + NuSeries::series_to_untagged(res) + } + Err(e) => UntaggedValue::Error(ShellError::labeled_error( + "Casting error", + format!("{}", e), + span, + )), + } +} + +fn contains_series_pat(series: &Series, pat: &str, span: &Span) -> UntaggedValue { + let casted = series.utf8(); + match casted { + Ok(casted) => { + let res = casted.contains(pat); + + match res { + Ok(res) => { + let res = res.into_series(); + NuSeries::series_to_untagged(res) + } + Err(e) => UntaggedValue::Error(ShellError::labeled_error( + "Search error", + format!("{}", e), + span, + )), + } + } + Err(e) => UntaggedValue::Error(ShellError::labeled_error( + "Casting error", + format!("{}", e), + span, + )), + } +} diff --git a/crates/nu-data/src/lib.rs b/crates/nu-data/src/lib.rs index feb179e5c6..f37ce7bfaf 100644 --- a/crates/nu-data/src/lib.rs +++ b/crates/nu-data/src/lib.rs @@ -7,4 +7,7 @@ pub mod types; pub mod utils; pub mod value; +#[cfg(feature = "dataframe")] +pub mod dataframe; + pub use dict::TaggedListBuilder; diff --git a/crates/nu-data/src/value.rs b/crates/nu-data/src/value.rs index 80c1fdd9e1..dca645d721 100644 --- a/crates/nu-data/src/value.rs +++ b/crates/nu-data/src/value.rs @@ -12,9 +12,6 @@ use num_bigint::BigInt; use num_traits::{ToPrimitive, Zero}; use std::collections::HashMap; -#[cfg(feature = "dataframe")] -use nu_protocol::dataframe::{NuSeries, PolarsData}; - pub struct Date; impl Date { @@ -494,51 +491,6 @@ pub fn compute_values( } _ => Err((left.type_name(), right.type_name())), }, - #[cfg(feature = "dataframe")] - ( - UntaggedValue::DataFrame(PolarsData::Series(lhs)), - UntaggedValue::DataFrame(PolarsData::Series(rhs)), - ) => { - if lhs.as_ref().dtype() == rhs.as_ref().dtype() { - let result = match operator { - Operator::Plus => { - let mut res = lhs.as_ref() + rhs.as_ref(); - let name = format!("sum_{}_{}", lhs.as_ref().name(), rhs.as_ref().name()); - let res = res.rename(name.as_ref()); - Ok(res.clone()) - } - Operator::Minus => { - let mut res = lhs.as_ref() - rhs.as_ref(); - let name = format!("sub_{}_{}", lhs.as_ref().name(), rhs.as_ref().name()); - let res = res.rename(name.as_ref()); - Ok(res.clone()) - } - Operator::Multiply => { - let mut res = lhs.as_ref() * rhs.as_ref(); - let name = format!("mul_{}_{}", lhs.as_ref().name(), rhs.as_ref().name()); - let res = res.rename(name.as_ref()); - Ok(res.clone()) - } - Operator::Divide => { - let mut res = lhs.as_ref() / rhs.as_ref(); - let name = format!("div_{}_{}", lhs.as_ref().name(), rhs.as_ref().name()); - let res = res.rename(name.as_ref()); - Ok(res.clone()) - } - Operator::Modulo => { - let mut res = lhs.as_ref() % rhs.as_ref(); - let name = format!("mod_{}_{}", lhs.as_ref().name(), rhs.as_ref().name()); - let res = res.rename(name.as_ref()); - Ok(res.clone()) - } - _ => Err((left.type_name(), right.type_name())), - }?; - - Ok(NuSeries::series_to_untagged(result)) - } else { - Err((left.type_name(), right.type_name())) - } - } _ => Err((left.type_name(), right.type_name())), } } diff --git a/crates/nu-engine/Cargo.toml b/crates/nu-engine/Cargo.toml index 80a00fcaf6..61f0bf6957 100644 --- a/crates/nu-engine/Cargo.toml +++ b/crates/nu-engine/Cargo.toml @@ -65,3 +65,4 @@ hamcrest2 = "0.3.0" rustyline-support = [] dirs = ["dirs-next"] trash-support = ["trash"] +dataframe = ["nu-protocol/dataframe"] diff --git a/crates/nu-engine/src/evaluate/operator.rs b/crates/nu-engine/src/evaluate/operator.rs index e27964d477..c7b0e3d70a 100644 --- a/crates/nu-engine/src/evaluate/operator.rs +++ b/crates/nu-engine/src/evaluate/operator.rs @@ -4,11 +4,29 @@ use nu_protocol::hir::Operator; use nu_protocol::{Primitive, ShellTypeName, UntaggedValue, Value}; use std::ops::Not; +#[cfg(feature = "dataframe")] +use nu_data::dataframe::{compute_between_series, compute_series_single_value}; +#[cfg(feature = "dataframe")] +use nu_protocol::dataframe::PolarsData; + pub fn apply_operator( op: Operator, left: &Value, right: &Value, ) -> Result { + #[cfg(feature = "dataframe")] + if let ( + UntaggedValue::DataFrame(PolarsData::Series(_)), + UntaggedValue::DataFrame(PolarsData::Series(_)), + ) = (&left.value, &right.value) + { + return compute_between_series(op, left, right); + } else if let (UntaggedValue::DataFrame(PolarsData::Series(_)), UntaggedValue::Primitive(_)) = + (&left.value, &right.value) + { + return compute_series_single_value(op, left, right); + } + match op { Operator::Equal | Operator::NotEqual diff --git a/crates/nu-protocol/Cargo.toml b/crates/nu-protocol/Cargo.toml index a9b02e4fb6..9852f21d27 100644 --- a/crates/nu-protocol/Cargo.toml +++ b/crates/nu-protocol/Cargo.toml @@ -32,10 +32,10 @@ toml = "0.5.8" [dependencies.polars] git = "https://github.com/pola-rs/polars" -rev = "a5f17b0a6e3e05ff6be789aa24a7cae54fd400dd" -version = "0.14.0" +rev = "9e1506cca9fb646fc55f949ab6345290c3d198a7" +version = "0.14.1" optional = true -features = ["serde"] +features = ["serde", "rows"] [features] dataframe = ["polars"] diff --git a/crates/nu-protocol/src/dataframe/nu_dataframe.rs b/crates/nu-protocol/src/dataframe/nu_dataframe.rs index 319c811443..6a6b22dd65 100644 --- a/crates/nu-protocol/src/dataframe/nu_dataframe.rs +++ b/crates/nu-protocol/src/dataframe/nu_dataframe.rs @@ -185,7 +185,7 @@ impl NuDataFrame { } pub fn to_rows(&self, from_row: usize, to_row: usize) -> Result, ShellError> { - let df = &self.as_ref(); + let df = self.as_ref(); let column_names = df.get_column_names(); let mut values: Vec = Vec::new();