diff --git a/mps-interpreter/README.md b/mps-interpreter/README.md index f6cf072..66fe430 100644 --- a/mps-interpreter/README.md +++ b/mps-interpreter/README.md @@ -141,6 +141,10 @@ Combine multiple iterables in an interleaved pattern. This is a variant of union Combine multiple iterables in a sequential pattern. All items in iterable1 are returned, then all items in iterable2, ... until all provided iterables are depleted. There is no limit on the amount of iterables which can be provided as parameters. +#### intersection(iterable1, iterable2, ...); + +Combine multiple iterables such that only items that exist in iterable1 and iterable2 and ... are returned. The order of items from iterable1 is maintained. There is no limit on the amount of iterables which can be provided as parameters. + #### empty(); Empty iterator. Useful for deleting items using replacement filters. diff --git a/mps-interpreter/src/interpretor.rs b/mps-interpreter/src/interpretor.rs index 6f2db4c..e63c157 100644 --- a/mps-interpreter/src/interpretor.rs +++ b/mps-interpreter/src/interpretor.rs @@ -175,5 +175,6 @@ pub(crate) fn standard_vocab(vocabulary: &mut MpsLanguageDictionary) { .add(crate::lang::vocabulary::files_function_factory()) .add(crate::lang::vocabulary::empty_function_factory()) .add(crate::lang::vocabulary::reset_function_factory()) - .add(crate::lang::vocabulary::union_function_factory()); + .add(crate::lang::vocabulary::union_function_factory()) + .add(crate::lang::vocabulary::intersection_function_factory()); } diff --git a/mps-interpreter/src/item.rs b/mps-interpreter/src/item.rs index 4a60bbc..9347f3a 100644 --- a/mps-interpreter/src/item.rs +++ b/mps-interpreter/src/item.rs @@ -40,6 +40,41 @@ impl Display for MpsItem { } } +impl std::hash::Hash for MpsItem { + fn hash(&self, state: &mut H) where H: std::hash::Hasher { + // hashing is order-dependent, so the pseudo-random sorting of HashMap keys + // prevents it from working correctly without sorting + let mut keys: Vec<_> = self.fields.keys().collect(); + keys.as_mut_slice().sort(); + for key in keys { + let val = self.fields.get(key).unwrap(); + key.hash(state); + val.hash(state); + } + } +} + +impl std::cmp::PartialEq for MpsItem { + /*fn eq(&self, other: &Self) -> bool { + for (key, val) in self.fields.iter() { + if let Some(other_val) = other.fields.get(key) { + if other_val != val { + return false; + } + } else { + return false; + } + } + true + }*/ + + fn eq(&self, other: &Self) -> bool { + self.fields == other.fields + } +} + +impl std::cmp::Eq for MpsItem {} + /*pub(crate) trait MpsItemRuntimeUtil { fn get_field_runtime(&self, name: &str, op: &mut OpGetter) -> Result<&MpsTypePrimitive, RuntimeError>; } diff --git a/mps-interpreter/src/lang/error.rs b/mps-interpreter/src/lang/error.rs index 09d78a1..232e62d 100644 --- a/mps-interpreter/src/lang/error.rs +++ b/mps-interpreter/src/lang/error.rs @@ -52,6 +52,21 @@ impl Display for RuntimeError { } } +impl std::hash::Hash for RuntimeError { + fn hash(&self, state: &mut H) where H: std::hash::Hasher { + self.line.hash(state); + self.msg.hash(state); + } +} + +impl std::cmp::PartialEq for RuntimeError { + fn eq(&self, other: &Self) -> bool { + self.line == other.line && self.msg == other.msg + } +} + +impl std::cmp::Eq for RuntimeError {} + impl MpsLanguageError for RuntimeError { fn set_line(&mut self, line: usize) { self.line = line @@ -63,7 +78,7 @@ pub trait MpsLanguageError: Display + Debug { } // RuntimeError builder components -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Hash)] pub struct RuntimeMsg(pub String); impl RuntimeMsg { diff --git a/mps-interpreter/src/lang/type_primitives.rs b/mps-interpreter/src/lang/type_primitives.rs index 9f33394..733d3b0 100644 --- a/mps-interpreter/src/lang/type_primitives.rs +++ b/mps-interpreter/src/lang/type_primitives.rs @@ -144,6 +144,18 @@ impl PartialOrd for MpsTypePrimitive { } } +impl std::hash::Hash for MpsTypePrimitive { + fn hash(&self, state: &mut H) where H: std::hash::Hasher { + match self { + Self::String(s) => s.hash(state), + Self::Int(i) => i.hash(state), + Self::UInt(u) => u.hash(state), + Self::Float(f_) => (*f_ as u64).hash(state), + Self::Bool(b) => b.hash(state), + } + } +} + #[inline] fn map_ordering(ordering: std::cmp::Ordering) -> i8 { match ordering { diff --git a/mps-interpreter/src/lang/vocabulary/intersection.rs b/mps-interpreter/src/lang/vocabulary/intersection.rs new file mode 100644 index 0000000..a9369ec --- /dev/null +++ b/mps-interpreter/src/lang/vocabulary/intersection.rs @@ -0,0 +1,167 @@ +use std::collections::{VecDeque, HashSet}; +use std::fmt::{Debug, Display, Error, Formatter}; +use std::iter::Iterator; + +use crate::tokens::MpsToken; +use crate::MpsContext; + +use crate::lang::{MpsLanguageDictionary, PseudoOp}; +use crate::lang::{MpsFunctionFactory, MpsFunctionStatementFactory, MpsIteratorItem, MpsOp}; +use crate::lang::{RuntimeError, SyntaxError}; +use crate::lang::repeated_tokens; +use crate::lang::vocabulary::union::next_comma; + +#[derive(Debug)] +pub struct IntersectionStatement { + context: Option, + ops: Vec, + items: Option>, + original_order: Option>, + init_needed: bool, +} + +impl Display for IntersectionStatement { + fn fmt(&self, f: &mut Formatter) -> Result<(), Error> { + let mut ops_str = "".to_owned(); + for i in 0..self.ops.len() { + ops_str += &self.ops[i].to_string(); + if i != self.ops.len() - 1 { + ops_str += ", "; + } + } + write!(f, "intersection({})", ops_str) + } +} + +impl std::clone::Clone for IntersectionStatement { + fn clone(&self) -> Self { + Self { + context: None, + ops: self.ops.clone(), + items: None, + original_order: None, + init_needed: self.init_needed, + } + } +} + +impl Iterator for IntersectionStatement { + type Item = MpsIteratorItem; + + fn next(&mut self) -> Option { + if self.ops.len() == 0 { + return None; + } else if self.init_needed { + self.init_needed = false; + let real_op = match self.ops[0].try_real() { + Ok(op) => op, + Err(e) => return Some(Err(e)), + }; + real_op.enter(self.context.take().unwrap()); + let original_order: VecDeque = real_op.collect(); + let mut set: HashSet = original_order.iter().map(|x| x.to_owned()).collect(); + self.context = Some(real_op.escape()); + if self.ops.len() != 1 && !set.is_empty() { + for i in 1..self.ops.len() { + let real_op = match self.ops[i].try_real() { + Ok(op) => op, + Err(e) => return Some(Err(e)), + }; + real_op.enter(self.context.take().unwrap()); + let set2: HashSet = real_op.collect(); + self.context = Some(real_op.escape()); + set.retain(|item| set2.contains(item)); + } + } + self.original_order = Some(original_order); + self.items = Some(set); + self.init_needed = false; + } + let original_order = self.original_order.as_mut().unwrap(); + let set_items = self.items.as_ref().unwrap(); + while let Some(item) = original_order.pop_front() { + if set_items.contains(&item) { + return Some(item); + } + } + None + } + + fn size_hint(&self) -> (usize, Option) { + (0, None) + } +} + +impl MpsOp for IntersectionStatement { + fn enter(&mut self, ctx: MpsContext) { + self.context = Some(ctx) + } + + fn escape(&mut self) -> MpsContext { + self.context.take().unwrap() + } + + fn is_resetable(&self) -> bool { + true + } + + fn reset(&mut self) -> Result<(), RuntimeError> { + self.init_needed = true; + self.original_order = None; + self.items = None; + for op in &mut self.ops { + let real_op = op.try_real()?; + real_op.enter(self.context.take().unwrap()); + if real_op.is_resetable() { + let result = real_op.reset(); + self.context = Some(real_op.escape()); + result?; + } else { + self.context = Some(real_op.escape()); + } + + } + Ok(()) + } +} + +pub struct IntersectionFunctionFactory; + +impl MpsFunctionFactory for IntersectionFunctionFactory { + fn is_function(&self, name: &str) -> bool { + name == "intersection" || name == "n" + } + + fn build_function_params( + &self, + _name: String, + tokens: &mut VecDeque, + dict: &MpsLanguageDictionary, + ) -> Result { + // intersection(op1, op2, ...) + let operations = repeated_tokens(|tokens| { + if let Some(comma_pos) = next_comma(tokens) { + let end_tokens = tokens.split_off(comma_pos); + let op = dict.try_build_statement(tokens); + tokens.extend(end_tokens); + Ok(Some(PseudoOp::from(op?))) + } else { + Ok(Some(PseudoOp::from(dict.try_build_statement(tokens)?))) + } + }, MpsToken::Comma).ingest_all(tokens)?; + Ok(IntersectionStatement { + context: None, + ops: operations, + items: None, + original_order: None, + init_needed: true, + }) + } +} + +pub type IntersectionStatementFactory = MpsFunctionStatementFactory; + +#[inline(always)] +pub fn intersection_function_factory() -> IntersectionStatementFactory { + IntersectionStatementFactory::new(IntersectionFunctionFactory) +} diff --git a/mps-interpreter/src/lang/vocabulary/mod.rs b/mps-interpreter/src/lang/vocabulary/mod.rs index 27470da..94338ee 100644 --- a/mps-interpreter/src/lang/vocabulary/mod.rs +++ b/mps-interpreter/src/lang/vocabulary/mod.rs @@ -1,6 +1,7 @@ mod comment; mod empty; mod files; +mod intersection; mod repeat; mod reset; mod sql_init; @@ -12,6 +13,7 @@ mod variable_assign; pub use comment::{CommentStatement, CommentStatementFactory}; pub use empty::{empty_function_factory, EmptyStatementFactory}; pub use files::{files_function_factory, FilesStatementFactory}; +pub use intersection::{intersection_function_factory, IntersectionStatementFactory}; pub use repeat::{repeat_function_factory, RepeatStatementFactory}; pub use reset::{reset_function_factory, ResetStatementFactory}; pub use sql_init::{sql_init_function_factory, SqlInitStatementFactory}; diff --git a/mps-interpreter/src/lang/vocabulary/union.rs b/mps-interpreter/src/lang/vocabulary/union.rs index 44662b7..d8caf35 100644 --- a/mps-interpreter/src/lang/vocabulary/union.rs +++ b/mps-interpreter/src/lang/vocabulary/union.rs @@ -100,7 +100,7 @@ impl Iterator for UnionStatement { } fn size_hint(&self) -> (usize, Option) { - (0, Some(0)) + (0, None) } } @@ -179,7 +179,7 @@ pub fn union_function_factory() -> UnionStatementFactory { UnionStatementFactory::new(UnionFunctionFactory) } -fn next_comma(tokens: &VecDeque) -> Option { +pub(super) fn next_comma(tokens: &VecDeque) -> Option { let mut bracket_depth = 0; for i in 0..tokens.len() { let token = &tokens[i]; diff --git a/mps-interpreter/src/lib.rs b/mps-interpreter/src/lib.rs index 42a5eb5..582c594 100644 --- a/mps-interpreter/src/lib.rs +++ b/mps-interpreter/src/lib.rs @@ -139,6 +139,10 @@ //! //! Combine multiple iterables in a sequential pattern. All items in iterable1 are returned, then all items in iterable2, ... until all provided iterables are depleted. There is no limit on the amount of iterables which can be provided as parameters. //! +//! ### intersection(iterable1, iterable2, ...); +//! +//! Combine multiple iterables such that only items that exist in iterable1 and iterable2 and ... are returned. The order of items from iterable1 is maintained. There is no limit on the amount of iterables which can be provided as parameters. +//! //! ### empty(); //! //! Empty iterator. Useful for deleting items using replacement filters. diff --git a/mps-interpreter/tests/single_line.rs b/mps-interpreter/tests/single_line.rs index 7e1ece6..049b2cc 100644 --- a/mps-interpreter/tests/single_line.rs +++ b/mps-interpreter/tests/single_line.rs @@ -420,7 +420,7 @@ fn execute_unionfn_line() -> Result<(), Box> { true )?; execute_single_line( - "interlace(files(`~/Music/MusicFlac/Bruno Mars/24K Magic/`), files(`~/Music/MusicFlac/Bruno Mars/24K Magic/`))", + "interlace(empty(), files(`~/Music/MusicFlac/Bruno Mars/24K Magic/`))", false, true ) @@ -439,3 +439,27 @@ fn execute_regexfilter_line() -> Result<(), Box> { true, ) } + +#[test] +fn execute_intersectionfn_line() -> Result<(), Box> { + execute_single_line( + "intersection(files(`~/Music/MusicFlac/Bruno Mars/24K Magic/`))", + false, + true, + )?; + execute_single_line( + "n(files(`~/Music/MusicFlac/Bruno Mars/24K Magic/`), n(files(`~/Music/MusicFlac/Bruno Mars/24K Magic/`), files(`~/Music/MusicFlac/Bruno Mars/24K Magic/`)))", + false, + true, + )?; + execute_single_line( + "intersection(files(`~/Music/MusicFlac/Bruno Mars/24K Magic/`), files(`~/Music/MusicFlac/Bruno Mars/24K Magic/`))", + false, + true + )?; + execute_single_line( + "n(empty(), files(`~/Music/MusicFlac/Bruno Mars/24K Magic/`))", + true, + true + ) +} diff --git a/src/help.rs b/src/help.rs index 99de522..8c88978 100644 --- a/src/help.rs +++ b/src/help.rs @@ -42,6 +42,9 @@ These always return an iterable which can be manipulated. union(iterable1, iterable2, ...) Combine multiple iterables in a sequential pattern. All items in iterable1 are returned, then all items in iterable2, ... until all provided iterables are depleted. There is no limit on the amount of iterables which can be provided as parameters. + intersection(iterable1, iterable2, ...); + Combine multiple iterables such that only items that exist in iterable1 and iterable2 and ... are returned. The order of items from iterable1 is maintained. There is no limit on the amount of iterables which can be provided as parameters. + empty() Empty iterator. Useful for deleting items using replacement filters.";