added flags for aggregate and sort for polars pivot

This commit is contained in:
Jack Wright 2024-07-08 09:52:14 -07:00
parent 23ce14f07c
commit fff258a056

View File

@ -4,12 +4,12 @@ use nu_protocol::{
Value,
};
use polars_ops::pivot::pivot;
use polars_ops::pivot::{pivot, PivotAgg};
use crate::{
dataframe::values::utils::convert_columns_string,
values::{CustomValueSupport, PolarsPluginObject},
Cacheable, PolarsPlugin,
PolarsPlugin,
};
use super::super::values::NuDataFrame;
@ -48,14 +48,25 @@ impl PluginCommand for PivotDF {
"column names used as value columns",
Some('v'),
)
.input_output_type(
Type::Custom("dataframe".into()),
Type::Custom("dataframe".into()),
.named(
"aggregate",
SyntaxShape::String,
"Aggregation to apply when pivoting. The following are supported: first, sum, min, max, mean, median, count, last",
Some('a'),
)
.switch(
"sort",
"Sort columns",
Some('s'),
)
.switch(
"streamable",
"Whether or not to use the polars streaming engine. Only valid for lazy dataframes",
Some('s'),
Some('t'),
)
.input_output_type(
Type::Custom("dataframe".into()),
Type::Custom("dataframe".into()),
)
.category(Category::Custom("dataframe".into()))
}
@ -106,6 +117,13 @@ fn command_eager(
check_column_datatypes(df.as_ref(), &index_col_string, index_col_span)?;
check_column_datatypes(df.as_ref(), &val_col_string, val_col_span)?;
let aggregate: Option<PivotAgg> = call
.get_flag::<String>("aggregate")?
.map(pivot_agg_for_str)
.transpose()?;
let sort = call.has_flag("sort")?;
let polars_df = df.to_polars();
// todo add other args
let pivoted = pivot(
@ -113,8 +131,8 @@ fn command_eager(
&on_col_string,
Some(&index_col_string),
Some(&val_col_string),
false,
None,
sort,
aggregate,
None,
)
.map_err(|e| ShellError::GenericError {
@ -186,6 +204,28 @@ fn check_column_datatypes<T: AsRef<str>>(
Ok(())
}
fn pivot_agg_for_str(agg: impl AsRef<str>) -> Result<PivotAgg, ShellError> {
match agg.as_ref() {
"first" => Ok(PivotAgg::First),
"sum" => Ok(PivotAgg::Sum),
"min" => Ok(PivotAgg::Min),
"max" => Ok(PivotAgg::Max),
"mean" => Ok(PivotAgg::Mean),
"median" => Ok(PivotAgg::Median),
"count" => Ok(PivotAgg::Count),
"last" => Ok(PivotAgg::Last),
s => Err(ShellError::GenericError {
error: format!("{s} is not a valid aggregation"),
msg: "".into(),
span: None,
help: Some(
"Use one of the following: first, sum, min, max, mean, median, count, last".into(),
),
inner: vec![],
}),
}
}
#[cfg(test)]
mod test {
use crate::test::test_polars_plugin_command;