added flags for aggregate and sort for polars pivot

This commit is contained in:
Jack Wright 2024-07-08 09:52:14 -07:00
parent 23ce14f07c
commit fff258a056

View File

@ -4,12 +4,12 @@ use nu_protocol::{
Value, Value,
}; };
use polars_ops::pivot::pivot; use polars_ops::pivot::{pivot, PivotAgg};
use crate::{ use crate::{
dataframe::values::utils::convert_columns_string, dataframe::values::utils::convert_columns_string,
values::{CustomValueSupport, PolarsPluginObject}, values::{CustomValueSupport, PolarsPluginObject},
Cacheable, PolarsPlugin, PolarsPlugin,
}; };
use super::super::values::NuDataFrame; use super::super::values::NuDataFrame;
@ -48,14 +48,25 @@ impl PluginCommand for PivotDF {
"column names used as value columns", "column names used as value columns",
Some('v'), Some('v'),
) )
.input_output_type( .named(
Type::Custom("dataframe".into()), "aggregate",
Type::Custom("dataframe".into()), SyntaxShape::String,
"Aggregation to apply when pivoting. The following are supported: first, sum, min, max, mean, median, count, last",
Some('a'),
)
.switch(
"sort",
"Sort columns",
Some('s'),
) )
.switch( .switch(
"streamable", "streamable",
"Whether or not to use the polars streaming engine. Only valid for lazy dataframes", "Whether or not to use the polars streaming engine. Only valid for lazy dataframes",
Some('s'), Some('t'),
)
.input_output_type(
Type::Custom("dataframe".into()),
Type::Custom("dataframe".into()),
) )
.category(Category::Custom("dataframe".into())) .category(Category::Custom("dataframe".into()))
} }
@ -106,6 +117,13 @@ fn command_eager(
check_column_datatypes(df.as_ref(), &index_col_string, index_col_span)?; check_column_datatypes(df.as_ref(), &index_col_string, index_col_span)?;
check_column_datatypes(df.as_ref(), &val_col_string, val_col_span)?; check_column_datatypes(df.as_ref(), &val_col_string, val_col_span)?;
let aggregate: Option<PivotAgg> = call
.get_flag::<String>("aggregate")?
.map(pivot_agg_for_str)
.transpose()?;
let sort = call.has_flag("sort")?;
let polars_df = df.to_polars(); let polars_df = df.to_polars();
// todo add other args // todo add other args
let pivoted = pivot( let pivoted = pivot(
@ -113,8 +131,8 @@ fn command_eager(
&on_col_string, &on_col_string,
Some(&index_col_string), Some(&index_col_string),
Some(&val_col_string), Some(&val_col_string),
false, sort,
None, aggregate,
None, None,
) )
.map_err(|e| ShellError::GenericError { .map_err(|e| ShellError::GenericError {
@ -186,6 +204,28 @@ fn check_column_datatypes<T: AsRef<str>>(
Ok(()) Ok(())
} }
fn pivot_agg_for_str(agg: impl AsRef<str>) -> Result<PivotAgg, ShellError> {
match agg.as_ref() {
"first" => Ok(PivotAgg::First),
"sum" => Ok(PivotAgg::Sum),
"min" => Ok(PivotAgg::Min),
"max" => Ok(PivotAgg::Max),
"mean" => Ok(PivotAgg::Mean),
"median" => Ok(PivotAgg::Median),
"count" => Ok(PivotAgg::Count),
"last" => Ok(PivotAgg::Last),
s => Err(ShellError::GenericError {
error: format!("{s} is not a valid aggregation"),
msg: "".into(),
span: None,
help: Some(
"Use one of the following: first, sum, min, max, mean, median, count, last".into(),
),
inner: vec![],
}),
}
}
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use crate::test::test_polars_plugin_command; use crate::test::test_polars_plugin_command;