From 1d0483c94658963a701bc56f29a209699bf41208 Mon Sep 17 00:00:00 2001 From: Fernando Herrera Date: Mon, 28 Jun 2021 11:17:37 +0100 Subject: [PATCH] Casting operations for Series with differents types (#3702) * Type in command description * filter name change * Clean column name * Clippy error and updated polars version * Lint correction in file * CSV Infer schema optional * Correct float operations * changes in series castings to allow other types * Clippy error correction --- crates/nu-data/src/dataframe.rs | 349 +++++++++++++++++++++----------- 1 file changed, 232 insertions(+), 117 deletions(-) diff --git a/crates/nu-data/src/dataframe.rs b/crates/nu-data/src/dataframe.rs index c880203acb..7995a31464 100644 --- a/crates/nu-data/src/dataframe.rs +++ b/crates/nu-data/src/dataframe.rs @@ -8,10 +8,9 @@ use nu_protocol::{ use nu_source::Span; use num_traits::ToPrimitive; -use num_bigint::BigInt; use polars::prelude::{ BooleanType, ChunkCompare, ChunkedArray, DataType, Float64Type, Int64Type, IntoSeries, - NumOpsDispatchChecked, Series, + NumOpsDispatchChecked, PolarsError, Series, }; use std::ops::{Add, BitAnd, BitOr, Div, Mul, Sub}; @@ -202,19 +201,20 @@ pub fn compute_series_single_value( UntaggedValue::Primitive(Primitive::Int(val)) => Ok(compute_series_i64( lhs.as_ref(), val, - <&ChunkedArray>::add, + >::add, &left.tag.span, )), - UntaggedValue::Primitive(Primitive::BigInt(val)) => Ok(compute_series_bigint( + UntaggedValue::Primitive(Primitive::BigInt(val)) => Ok(compute_series_i64( lhs.as_ref(), - val, - <&ChunkedArray>::add, + &val.to_i64() + .expect("Internal error: protocol did not use compatible decimal"), + >::add, &left.tag.span, )), UntaggedValue::Primitive(Primitive::Decimal(val)) => Ok(compute_series_decimal( lhs.as_ref(), val, - <&ChunkedArray>::add, + >::add, &left.tag.span, )), _ => Ok(UntaggedValue::Error( @@ -231,19 +231,20 @@ pub fn compute_series_single_value( UntaggedValue::Primitive(Primitive::Int(val)) => Ok(compute_series_i64( lhs.as_ref(), val, - <&ChunkedArray>::sub, + >::sub, &left.tag.span, )), - UntaggedValue::Primitive(Primitive::BigInt(val)) => Ok(compute_series_bigint( + UntaggedValue::Primitive(Primitive::BigInt(val)) => Ok(compute_series_i64( lhs.as_ref(), - val, - <&ChunkedArray>::sub, + &val.to_i64() + .expect("Internal error: protocol did not use compatible decimal"), + >::sub, &left.tag.span, )), UntaggedValue::Primitive(Primitive::Decimal(val)) => Ok(compute_series_decimal( lhs.as_ref(), val, - <&ChunkedArray>::sub, + >::sub, &left.tag.span, )), _ => Ok(UntaggedValue::Error( @@ -260,19 +261,20 @@ pub fn compute_series_single_value( UntaggedValue::Primitive(Primitive::Int(val)) => Ok(compute_series_i64( lhs.as_ref(), val, - <&ChunkedArray>::mul, + >::mul, &left.tag.span, )), - UntaggedValue::Primitive(Primitive::BigInt(val)) => Ok(compute_series_bigint( + UntaggedValue::Primitive(Primitive::BigInt(val)) => Ok(compute_series_i64( lhs.as_ref(), - val, - <&ChunkedArray>::mul, + &val.to_i64() + .expect("Internal error: protocol did not use compatible decimal"), + >::mul, &left.tag.span, )), UntaggedValue::Primitive(Primitive::Decimal(val)) => Ok(compute_series_decimal( lhs.as_ref(), val, - <&ChunkedArray>::mul, + >::mul, &left.tag.span, )), _ => Ok(UntaggedValue::Error( @@ -297,7 +299,7 @@ pub fn compute_series_single_value( Ok(compute_series_i64( lhs.as_ref(), val, - <&ChunkedArray>::div, + >::div, &left.tag.span, )) } @@ -310,10 +312,11 @@ pub fn compute_series_single_value( &right.tag.span, ))) } else { - Ok(compute_series_bigint( + Ok(compute_series_i64( lhs.as_ref(), - val, - <&ChunkedArray>::div, + &val.to_i64() + .expect("Internal error: protocol did not use compatible decimal"), + >::div, &left.tag.span, )) } @@ -329,7 +332,7 @@ pub fn compute_series_single_value( Ok(compute_series_decimal( lhs.as_ref(), val, - <&ChunkedArray>::div, + >::div, &left.tag.span, )) } @@ -352,9 +355,10 @@ pub fn compute_series_single_value( ChunkedArray::eq, &left.tag.span, )), - UntaggedValue::Primitive(Primitive::BigInt(val)) => Ok(compare_series_bigint( + UntaggedValue::Primitive(Primitive::BigInt(val)) => Ok(compare_series_i64( lhs.as_ref(), - val, + &val.to_i64() + .expect("Internal error: protocol did not use compatible decimal"), ChunkedArray::eq, &left.tag.span, )), @@ -379,9 +383,10 @@ pub fn compute_series_single_value( ChunkedArray::neq, &left.tag.span, )), - UntaggedValue::Primitive(Primitive::BigInt(val)) => Ok(compare_series_bigint( + UntaggedValue::Primitive(Primitive::BigInt(val)) => Ok(compare_series_i64( lhs.as_ref(), - val, + &val.to_i64() + .expect("Internal error: protocol did not use compatible decimal"), ChunkedArray::neq, &left.tag.span, )), @@ -409,9 +414,10 @@ pub fn compute_series_single_value( ChunkedArray::lt, &left.tag.span, )), - UntaggedValue::Primitive(Primitive::BigInt(val)) => Ok(compare_series_bigint( + UntaggedValue::Primitive(Primitive::BigInt(val)) => Ok(compare_series_i64( lhs.as_ref(), - val, + &val.to_i64() + .expect("Internal error: protocol did not use compatible decimal"), ChunkedArray::lt, &left.tag.span, )), @@ -436,9 +442,10 @@ pub fn compute_series_single_value( ChunkedArray::lt_eq, &left.tag.span, )), - UntaggedValue::Primitive(Primitive::BigInt(val)) => Ok(compare_series_bigint( + UntaggedValue::Primitive(Primitive::BigInt(val)) => Ok(compare_series_i64( lhs.as_ref(), - val, + &val.to_i64() + .expect("Internal error: protocol did not use compatible decimal"), ChunkedArray::lt_eq, &left.tag.span, )), @@ -466,9 +473,10 @@ pub fn compute_series_single_value( ChunkedArray::gt, &left.tag.span, )), - UntaggedValue::Primitive(Primitive::BigInt(val)) => Ok(compare_series_bigint( + UntaggedValue::Primitive(Primitive::BigInt(val)) => Ok(compare_series_i64( lhs.as_ref(), - val, + &val.to_i64() + .expect("Internal error: protocol did not use compatible decimal"), ChunkedArray::gt, &left.tag.span, )), @@ -493,9 +501,10 @@ pub fn compute_series_single_value( ChunkedArray::gt_eq, &left.tag.span, )), - UntaggedValue::Primitive(Primitive::BigInt(val)) => Ok(compare_series_bigint( + UntaggedValue::Primitive(Primitive::BigInt(val)) => Ok(compare_series_i64( lhs.as_ref(), - val, + &val.to_i64() + .expect("Internal error: protocol did not use compatible decimal"), ChunkedArray::gt_eq, &left.tag.span, )), @@ -540,14 +549,53 @@ pub fn compute_series_single_value( } } -fn compute_series_i64<'r, F>(series: &'r Series, val: &i64, f: F, span: &Span) -> UntaggedValue +fn compute_series_i64(series: &Series, val: &i64, f: F, span: &Span) -> UntaggedValue where - F: Fn(&'r ChunkedArray, i64) -> ChunkedArray, + F: Fn(ChunkedArray, i64) -> ChunkedArray, +{ + match series.dtype() { + DataType::UInt32 | DataType::Int32 | DataType::UInt64 => { + let to_i64 = series.cast_with_dtype(&DataType::Int64); + + match to_i64 { + Ok(series) => { + let casted = series.i64(); + compute_casted_i64(casted, *val, f, span) + } + Err(e) => UntaggedValue::Error(ShellError::labeled_error( + "Casting error", + format!("{}", e), + span, + )), + } + } + DataType::Int64 => { + let casted = series.i64(); + compute_casted_i64(casted, *val, f, span) + } + _ => UntaggedValue::Error(ShellError::labeled_error( + "Casting error", + format!( + "Series of type {} can not be used for operations with an i64 value", + series.dtype() + ), + span, + )), + } +} + +fn compute_casted_i64( + casted: Result<&ChunkedArray, PolarsError>, + val: i64, + f: F, + span: &Span, +) -> UntaggedValue +where + F: Fn(ChunkedArray, i64) -> ChunkedArray, { - let casted = series.i64(); match casted { Ok(casted) => { - let res = f(casted, *val); + let res = f(casted.clone(), val); let res = res.into_series(); NuSeries::series_to_untagged(res) } @@ -559,98 +607,65 @@ where } } -fn compute_series_bigint<'r, F>( - series: &'r Series, - val: &BigInt, - f: F, - span: &Span, -) -> UntaggedValue +fn compute_series_decimal(series: &Series, val: &BigDecimal, f: F, span: &Span) -> UntaggedValue where - F: Fn(&'r ChunkedArray, i64) -> ChunkedArray, + F: Fn(ChunkedArray, f64) -> ChunkedArray, { - let casted = series.i64(); - match casted { - Ok(casted) => { - let res = f( - casted, - val.to_i64() - .expect("Internal error: protocol did not use compatible decimal"), - ); - let res = res.into_series(); - NuSeries::series_to_untagged(res) - } - Err(e) => UntaggedValue::Error(ShellError::labeled_error( - "Casting error", - format!("{}", e), - span, - )), - } -} + match series.dtype() { + DataType::Float32 => { + let to_f64 = series.cast_with_dtype(&DataType::Float64); -fn compute_series_decimal<'r, F>( - series: &'r Series, - val: &BigDecimal, - f: F, - span: &Span, -) -> UntaggedValue -where - F: Fn(&'r ChunkedArray, f64) -> ChunkedArray, -{ - let casted = series.f64(); - match casted { - Ok(casted) => { - let res = f( + match to_f64 { + Ok(series) => { + let casted = series.f64(); + compute_casted_f64( + casted, + val.to_f64() + .expect("Internal error: protocol did not use compatible decimal"), + f, + span, + ) + } + Err(e) => UntaggedValue::Error(ShellError::labeled_error( + "Casting error", + format!("{}", e), + span, + )), + } + } + DataType::Float64 => { + let casted = series.f64(); + compute_casted_f64( casted, val.to_f64() .expect("Internal error: protocol did not use compatible decimal"), - ); - let res = res.into_series(); - NuSeries::series_to_untagged(res) + f, + span, + ) } - Err(e) => UntaggedValue::Error(ShellError::labeled_error( + _ => UntaggedValue::Error(ShellError::labeled_error( "Casting error", - format!("{}", e), + format!( + "Series of type {} can not be used for operations with a decimal value", + series.dtype() + ), span, )), } } -fn compare_series_i64<'r, F>(series: &'r Series, val: &i64, f: F, span: &Span) -> UntaggedValue -where - F: Fn(&'r ChunkedArray, i64) -> ChunkedArray, -{ - let casted = series.i64(); - match casted { - Ok(casted) => { - let res = f(casted, *val); - let res = res.into_series(); - NuSeries::series_to_untagged(res) - } - Err(e) => UntaggedValue::Error(ShellError::labeled_error( - "Casting error", - format!("{}", e), - span, - )), - } -} - -fn compare_series_bigint<'r, F>( - series: &'r Series, - val: &BigInt, +fn compute_casted_f64( + casted: Result<&ChunkedArray, PolarsError>, + val: f64, f: F, span: &Span, ) -> UntaggedValue where - F: Fn(&'r ChunkedArray, i64) -> ChunkedArray, + F: Fn(ChunkedArray, f64) -> ChunkedArray, { - let casted = series.i64(); match casted { Ok(casted) => { - let res = f( - casted, - val.to_i64() - .expect("Internal error: protocol did not use compatible decimal"), - ); + let res = f(casted.clone(), val); let res = res.into_series(); NuSeries::series_to_untagged(res) } @@ -662,23 +677,123 @@ where } } -fn compare_series_decimal<'r, F>( - series: &'r Series, - val: &BigDecimal, +fn compare_series_i64(series: &Series, val: &i64, f: F, span: &Span) -> UntaggedValue +where + F: Fn(&ChunkedArray, i64) -> ChunkedArray, +{ + match series.dtype() { + DataType::UInt32 | DataType::Int32 | DataType::UInt64 => { + let to_i64 = series.cast_with_dtype(&DataType::Int64); + + match to_i64 { + Ok(series) => { + let casted = series.i64(); + compare_casted_i64(casted, *val, f, span) + } + Err(e) => UntaggedValue::Error(ShellError::labeled_error( + "Casting error", + format!("{}", e), + span, + )), + } + } + DataType::Int64 => { + let casted = series.i64(); + compare_casted_i64(casted, *val, f, span) + } + _ => UntaggedValue::Error(ShellError::labeled_error( + "Casting error", + format!( + "Series of type {} can not be used for operations with an i64 value", + series.dtype() + ), + span, + )), + } +} + +fn compare_casted_i64( + casted: Result<&ChunkedArray, PolarsError>, + val: i64, f: F, span: &Span, ) -> UntaggedValue where - F: Fn(&'r ChunkedArray, f64) -> ChunkedArray, + F: Fn(&ChunkedArray, i64) -> ChunkedArray, { - let casted = series.f64(); match casted { Ok(casted) => { - let res = f( + let res = f(casted, val); + let res = res.into_series(); + NuSeries::series_to_untagged(res) + } + Err(e) => UntaggedValue::Error(ShellError::labeled_error( + "Casting error", + format!("{}", e), + span, + )), + } +} + +fn compare_series_decimal(series: &Series, val: &BigDecimal, f: F, span: &Span) -> UntaggedValue +where + F: Fn(&ChunkedArray, f64) -> ChunkedArray, +{ + match series.dtype() { + DataType::Float32 => { + let to_f64 = series.cast_with_dtype(&DataType::Float64); + + match to_f64 { + Ok(series) => { + let casted = series.f64(); + compare_casted_f64( + casted, + val.to_f64() + .expect("Internal error: protocol did not use compatible decimal"), + f, + span, + ) + } + Err(e) => UntaggedValue::Error(ShellError::labeled_error( + "Casting error", + format!("{}", e), + span, + )), + } + } + DataType::Float64 => { + let casted = series.f64(); + compare_casted_f64( casted, val.to_f64() .expect("Internal error: protocol did not use compatible decimal"), - ); + f, + span, + ) + } + _ => UntaggedValue::Error(ShellError::labeled_error( + "Casting error", + format!( + "Series of type {} can not be used for operations with a decimal value", + series.dtype() + ), + span, + )), + } +} + +fn compare_casted_f64( + casted: Result<&ChunkedArray, PolarsError>, + val: f64, + f: F, + span: &Span, +) -> UntaggedValue +where + F: Fn(&ChunkedArray, f64) -> ChunkedArray, +{ + match casted { + Ok(casted) => { + let res = f(casted, val); let res = res.into_series(); NuSeries::series_to_untagged(res) }