|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
use super::Point; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/// Strategy for combining several [`Point`]s into one representative `Point`.
///
/// The `Send + Sync` supertraits mean a merger value can be shared between threads.
pub trait Merge: Send + Sync {
    /// Combines `points` into a single point.
    ///
    /// Every implementation in this module panics when `points` is empty or when the
    /// points do not all have the same dimensionality; other implementors may differ.
    fn merge(&self, points: &[Point]) -> Point;

    /// Returns a short identifier for this strategy (for example `"mean"`).
    fn name(&self) -> &'static str;
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#[derive(Clone, Copy, Debug, Default)] |
|
|
pub struct Mean; |
|
|
|
|
|
impl Merge for Mean { |
|
|
fn merge(&self, points: &[Point]) -> Point { |
|
|
assert!(!points.is_empty(), "Cannot merge empty slice"); |
|
|
|
|
|
let dims = points[0].dimensionality(); |
|
|
let n = points.len() as f32; |
|
|
|
|
|
let mut result = vec![0.0; dims]; |
|
|
for p in points { |
|
|
assert_eq!( |
|
|
p.dimensionality(), |
|
|
dims, |
|
|
"All points must have same dimensionality" |
|
|
); |
|
|
for (r, d) in result.iter_mut().zip(p.dims()) { |
|
|
*r += d / n; |
|
|
} |
|
|
} |
|
|
|
|
|
Point::new(result) |
|
|
} |
|
|
|
|
|
fn name(&self) -> &'static str { |
|
|
"mean" |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#[derive(Clone, Debug)] |
|
|
pub struct WeightedMean { |
|
|
weights: Vec<f32>, |
|
|
} |
|
|
|
|
|
impl WeightedMean { |
|
|
|
|
|
|
|
|
|
|
|
pub fn new(weights: Vec<f32>) -> Self { |
|
|
Self { weights } |
|
|
} |
|
|
|
|
|
|
|
|
pub fn uniform(n: usize) -> Self { |
|
|
Self { |
|
|
weights: vec![1.0; n], |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pub fn recency(n: usize, decay: f32) -> Self { |
|
|
let weights: Vec<f32> = (0..n).map(|i| decay.powi((n - 1 - i) as i32)).collect(); |
|
|
Self { weights } |
|
|
} |
|
|
} |
|
|
|
|
|
impl Merge for WeightedMean { |
|
|
fn merge(&self, points: &[Point]) -> Point { |
|
|
assert!(!points.is_empty(), "Cannot merge empty slice"); |
|
|
assert_eq!( |
|
|
points.len(), |
|
|
self.weights.len(), |
|
|
"Number of points must match number of weights" |
|
|
); |
|
|
|
|
|
let dims = points[0].dimensionality(); |
|
|
let total_weight: f32 = self.weights.iter().sum(); |
|
|
|
|
|
let mut result = vec![0.0; dims]; |
|
|
for (p, &w) in points.iter().zip(&self.weights) { |
|
|
assert_eq!( |
|
|
p.dimensionality(), |
|
|
dims, |
|
|
"All points must have same dimensionality" |
|
|
); |
|
|
let normalized_w = w / total_weight; |
|
|
for (r, d) in result.iter_mut().zip(p.dims()) { |
|
|
*r += d * normalized_w; |
|
|
} |
|
|
} |
|
|
|
|
|
Point::new(result) |
|
|
} |
|
|
|
|
|
fn name(&self) -> &'static str { |
|
|
"weighted_mean" |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#[derive(Clone, Copy, Debug, Default)] |
|
|
pub struct MaxPool; |
|
|
|
|
|
impl Merge for MaxPool { |
|
|
fn merge(&self, points: &[Point]) -> Point { |
|
|
assert!(!points.is_empty(), "Cannot merge empty slice"); |
|
|
|
|
|
let dims = points[0].dimensionality(); |
|
|
let mut result = points[0].dims().to_vec(); |
|
|
|
|
|
for p in &points[1..] { |
|
|
assert_eq!( |
|
|
p.dimensionality(), |
|
|
dims, |
|
|
"All points must have same dimensionality" |
|
|
); |
|
|
for (r, d) in result.iter_mut().zip(p.dims()) { |
|
|
*r = r.max(*d); |
|
|
} |
|
|
} |
|
|
|
|
|
Point::new(result) |
|
|
} |
|
|
|
|
|
fn name(&self) -> &'static str { |
|
|
"max_pool" |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#[derive(Clone, Copy, Debug, Default)] |
|
|
pub struct MinPool; |
|
|
|
|
|
impl Merge for MinPool { |
|
|
fn merge(&self, points: &[Point]) -> Point { |
|
|
assert!(!points.is_empty(), "Cannot merge empty slice"); |
|
|
|
|
|
let dims = points[0].dimensionality(); |
|
|
let mut result = points[0].dims().to_vec(); |
|
|
|
|
|
for p in &points[1..] { |
|
|
assert_eq!( |
|
|
p.dimensionality(), |
|
|
dims, |
|
|
"All points must have same dimensionality" |
|
|
); |
|
|
for (r, d) in result.iter_mut().zip(p.dims()) { |
|
|
*r = r.min(*d); |
|
|
} |
|
|
} |
|
|
|
|
|
Point::new(result) |
|
|
} |
|
|
|
|
|
fn name(&self) -> &'static str { |
|
|
"min_pool" |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#[derive(Clone, Copy, Debug, Default)] |
|
|
pub struct Sum; |
|
|
|
|
|
impl Merge for Sum { |
|
|
fn merge(&self, points: &[Point]) -> Point { |
|
|
assert!(!points.is_empty(), "Cannot merge empty slice"); |
|
|
|
|
|
let dims = points[0].dimensionality(); |
|
|
let mut result = vec![0.0; dims]; |
|
|
|
|
|
for p in points { |
|
|
assert_eq!( |
|
|
p.dimensionality(), |
|
|
dims, |
|
|
"All points must have same dimensionality" |
|
|
); |
|
|
for (r, d) in result.iter_mut().zip(p.dims()) { |
|
|
*r += d; |
|
|
} |
|
|
} |
|
|
|
|
|
Point::new(result) |
|
|
} |
|
|
|
|
|
fn name(&self) -> &'static str { |
|
|
"sum" |
|
|
} |
|
|
} |
|
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mean_single() {
        let points = vec![Point::new(vec![1.0, 2.0, 3.0])];
        let merged = Mean.merge(&points);
        assert_eq!(merged.dims(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_mean_multiple() {
        let points = vec![
            Point::new(vec![1.0, 2.0]),
            Point::new(vec![3.0, 4.0]),
        ];
        let merged = Mean.merge(&points);
        assert_eq!(merged.dims(), &[2.0, 3.0]);
    }

    #[test]
    fn test_weighted_mean() {
        let points = vec![
            Point::new(vec![0.0, 0.0]),
            Point::new(vec![10.0, 10.0]),
        ];

        let merger = WeightedMean::new(vec![1.0, 3.0]);
        let merged = merger.merge(&points);

        // (0 * 1 + 10 * 3) / (1 + 3) = 7.5 in each coordinate.
        assert!((merged.dims()[0] - 7.5).abs() < 0.0001);
        assert!((merged.dims()[1] - 7.5).abs() < 0.0001);
    }

    #[test]
    fn test_weighted_mean_recency() {
        let merger = WeightedMean::recency(3, 0.5);

        // weights[i] = 0.5^(n - 1 - i): the last point carries the largest weight.
        assert_eq!(merger.weights.len(), 3);
        assert!((merger.weights[0] - 0.25).abs() < 0.0001);
        assert!((merger.weights[1] - 0.5).abs() < 0.0001);
        assert!((merger.weights[2] - 1.0).abs() < 0.0001);
    }

    #[test]
    fn test_weighted_mean_uniform_matches_mean() {
        let points = vec![
            Point::new(vec![1.0, 2.0]),
            Point::new(vec![3.0, 4.0]),
        ];
        let merger = WeightedMean::uniform(2);
        assert_eq!(merger.weights, vec![1.0, 1.0]);
        assert_eq!(merger.merge(&points).dims(), Mean.merge(&points).dims());
    }

    #[test]
    #[should_panic(expected = "match number of weights")]
    fn test_weighted_mean_length_mismatch_panics() {
        let points = vec![Point::new(vec![1.0]), Point::new(vec![2.0])];
        WeightedMean::new(vec![1.0]).merge(&points);
    }

    #[test]
    fn test_max_pool() {
        let points = vec![
            Point::new(vec![1.0, 5.0, 2.0]),
            Point::new(vec![3.0, 2.0, 4.0]),
            Point::new(vec![2.0, 3.0, 1.0]),
        ];
        let merged = MaxPool.merge(&points);
        assert_eq!(merged.dims(), &[3.0, 5.0, 4.0]);
    }

    #[test]
    fn test_min_pool() {
        let points = vec![
            Point::new(vec![1.0, 5.0, 2.0]),
            Point::new(vec![3.0, 2.0, 4.0]),
            Point::new(vec![2.0, 3.0, 1.0]),
        ];
        let merged = MinPool.merge(&points);
        assert_eq!(merged.dims(), &[1.0, 2.0, 1.0]);
    }

    #[test]
    fn test_sum() {
        let points = vec![
            Point::new(vec![1.0, 2.0]),
            Point::new(vec![3.0, 4.0]),
        ];
        let merged = Sum.merge(&points);
        assert_eq!(merged.dims(), &[4.0, 6.0]);
    }

    #[test]
    fn test_merge_names() {
        assert_eq!(Mean.name(), "mean");
        assert_eq!(WeightedMean::uniform(1).name(), "weighted_mean");
        assert_eq!(MaxPool.name(), "max_pool");
        assert_eq!(MinPool.name(), "min_pool");
        assert_eq!(Sum.name(), "sum");
    }

    #[test]
    #[should_panic(expected = "Cannot merge empty")]
    fn test_merge_empty_panics() {
        let points: Vec<Point> = vec![];
        Mean.merge(&points);
    }

    #[test]
    #[should_panic(expected = "same dimensionality")]
    fn test_merge_dimension_mismatch_panics() {
        let points = vec![
            Point::new(vec![1.0, 2.0]),
            Point::new(vec![1.0, 2.0, 3.0]),
        ];
        Mean.merge(&points);
    }
}
|
|
|