diff --git a/crates/nu-protocol/src/value/did_you_mean.rs b/crates/nu-protocol/src/value/did_you_mean.rs index bd2eb7a548..2be9282791 100644 --- a/crates/nu-protocol/src/value/did_you_mean.rs +++ b/crates/nu-protocol/src/value/did_you_mean.rs @@ -24,28 +24,28 @@ pub fn did_you_mean(obj_source: &Value, field_tried: String) -> Option usize { - let n = str1.len(); - let m = str2.len(); + let mut current: Vec = (0..str1.len() + 1).collect(); + let str1_chars: Vec = str1.chars().collect(); + let str2_chars: Vec = str2.chars().collect(); - let mut current: Vec = (0..n + 1).collect(); - let a_vec: Vec = str1.chars().collect(); - let b_vec: Vec = str2.chars().collect(); + let str1_len = str1_chars.len(); + let str2_len = str2_chars.len(); - for i in 1..m + 1 { + for str2_index in 1..str2_len + 1 { let previous = current; - current = vec![0; n + 1]; - current[0] = i; - for j in 1..n + 1 { - let add = previous[j] + 1; - let delete = current[j - 1] + 1; - let mut change = previous[j - 1]; - if a_vec[j - 1] != b_vec[i - 1] { + current = vec![0; str1_len + 1]; + current[0] = str2_index; + for str1_index in 1..str1_len + 1 { + let add = previous[str1_index] + 1; + let delete = current[str1_index - 1] + 1; + let mut change = previous[str1_index - 1]; + if str1_chars[str1_index - 1] != str2_chars[str2_index - 1] { change += 1 } - current[j] = min3(add, delete, change); + current[str1_index] = min3(add, delete, change); } } - current[n] + current[str1_len] } fn min3(a: T, b: T, c: T) -> T { @@ -91,4 +91,12 @@ mod test { assert_eq!(None, did_you_mean(&empty_source, "hat".to_string())) } + + #[test] + fn test_levenshtein_distance() { + assert_eq!(super::levenshtein_distance("hello world", "hello world"), 0); + assert_eq!(super::levenshtein_distance("hello", "hello world"), 6); + assert_eq!(super::levenshtein_distance("°C", "°C"), 0); + assert_eq!(super::levenshtein_distance("°", "°C"), 1); + } }