diff --git a/src/commands/post.rs b/src/commands/post.rs index 5533e30652..5a77afd14b 100644 --- a/src/commands/post.rs +++ b/src/commands/post.rs @@ -11,6 +11,11 @@ use std::path::PathBuf; use std::str::FromStr; use surf::mime; +pub enum HeaderKind { + ContentType(String), + ContentLength(String), +} + pub struct Post; impl PerItemCommand for Post { @@ -24,6 +29,8 @@ impl PerItemCommand for Post { .required("body", SyntaxShape::Any) .named("user", SyntaxShape::Any) .named("password", SyntaxShape::Any) + .named("content-type", SyntaxShape::Any) + .named("content-length", SyntaxShape::Any) .switch("raw") } @@ -73,9 +80,11 @@ fn run( let registry = registry.clone(); let raw_args = raw_args.clone(); + let headers = get_headers(&call_info)?; + let stream = async_stream! { let (file_extension, contents, contents_tag, anchor_location) = - post(&path_str, &body, user, password, path_span, ®istry, &raw_args).await.unwrap(); + post(&path_str, &body, user, password, &headers, path_span, ®istry, &raw_args).await.unwrap(); let file_extension = if has_raw { None @@ -138,11 +147,67 @@ fn run( Ok(stream.to_output_stream()) } +fn get_headers(call_info: &CallInfo) -> Result, ShellError> { + let mut headers = vec![]; + + match extract_header_value(&call_info, "content-type") { + Ok(h) => match h { + Some(ct) => headers.push(HeaderKind::ContentType(ct)), + None => {} + }, + Err(e) => { + return Err(e); + } + }; + + match extract_header_value(&call_info, "content-length") { + Ok(h) => match h { + Some(cl) => headers.push(HeaderKind::ContentLength(cl)), + None => {} + }, + Err(e) => { + return Err(e); + } + }; + + Ok(headers) +} + +fn extract_header_value(call_info: &CallInfo, key: &str) -> Result, ShellError> { + if call_info.args.has(key) { + let tagged = call_info.args.get(key); + let val = match tagged { + Some(Tagged { + item: Value::Primitive(Primitive::String(s)), + .. + }) => s.clone(), + Some(Tagged { tag, .. }) => { + return Err(ShellError::labeled_error( + format!("{} not in expected format. Expected string.", key), + "post error", + tag, + )); + } + _ => { + return Err(ShellError::labeled_error( + format!("{} not in expected format. Expected string.", key), + "post error", + Tag::unknown(), + )); + } + }; + return Ok(Some(val)); + } + + Ok(None) +} + pub async fn post( location: &str, body: &Tagged, user: Option, password: Option, + headers: &Vec, tag: Tag, registry: &CommandRegistry, raw_args: &RawCommandArgs, @@ -164,6 +229,13 @@ pub async fn post( if let Some(login) = login { s = s.set_header("Authorization", format!("Basic {}", login)); } + + for h in headers { + s = match h { + HeaderKind::ContentType(ct) => s.set_header("Content-Type", ct), + HeaderKind::ContentLength(cl) => s.set_header("Content-Length", cl), + }; + } s.await } Tagged {