|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
use crate::core::{Blob, Id, PlacedPoint, Point}; |
|
|
use crate::core::config::ArmsConfig; |
|
|
use crate::ports::{Near, NearResult, Place, PlaceResult, SearchResult}; |
|
|
use crate::adapters::storage::MemoryStorage; |
|
|
use crate::adapters::index::FlatIndex; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pub struct Arms { |
|
|
|
|
|
config: ArmsConfig, |
|
|
|
|
|
|
|
|
storage: Box<dyn Place>, |
|
|
|
|
|
|
|
|
index: Box<dyn Near>, |
|
|
} |
|
|
|
|
|
impl Arms { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pub fn new(config: ArmsConfig) -> Self { |
|
|
let storage = Box::new(MemoryStorage::new(config.dimensionality)); |
|
|
let index = Box::new(FlatIndex::new( |
|
|
config.dimensionality, |
|
|
config.proximity.clone(), |
|
|
true, |
|
|
)); |
|
|
|
|
|
Self { |
|
|
config, |
|
|
storage, |
|
|
index, |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
pub fn with_adapters( |
|
|
config: ArmsConfig, |
|
|
storage: Box<dyn Place>, |
|
|
index: Box<dyn Near>, |
|
|
) -> Self { |
|
|
Self { |
|
|
config, |
|
|
storage, |
|
|
index, |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
pub fn config(&self) -> &ArmsConfig { |
|
|
&self.config |
|
|
} |
|
|
|
|
|
|
|
|
pub fn dimensionality(&self) -> usize { |
|
|
self.config.dimensionality |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pub fn place(&mut self, point: Point, blob: Blob) -> PlaceResult<Id> { |
|
|
|
|
|
let point = if self.config.normalize_on_insert { |
|
|
point.normalize() |
|
|
} else { |
|
|
point |
|
|
}; |
|
|
|
|
|
|
|
|
let id = self.storage.place(point.clone(), blob)?; |
|
|
|
|
|
|
|
|
if let Err(e) = self.index.add(id, &point) { |
|
|
|
|
|
self.storage.remove(id); |
|
|
return Err(crate::ports::PlaceError::StorageError(format!( |
|
|
"Index error: {:?}", |
|
|
e |
|
|
))); |
|
|
} |
|
|
|
|
|
Ok(id) |
|
|
} |
|
|
|
|
|
|
|
|
pub fn place_batch(&mut self, items: Vec<(Point, Blob)>) -> Vec<PlaceResult<Id>> { |
|
|
items |
|
|
.into_iter() |
|
|
.map(|(point, blob)| self.place(point, blob)) |
|
|
.collect() |
|
|
} |
|
|
|
|
|
|
|
|
pub fn remove(&mut self, id: Id) -> Option<PlacedPoint> { |
|
|
|
|
|
let _ = self.index.remove(id); |
|
|
|
|
|
|
|
|
self.storage.remove(id) |
|
|
} |
|
|
|
|
|
|
|
|
pub fn get(&self, id: Id) -> Option<&PlacedPoint> { |
|
|
self.storage.get(id) |
|
|
} |
|
|
|
|
|
|
|
|
pub fn contains(&self, id: Id) -> bool { |
|
|
self.storage.contains(id) |
|
|
} |
|
|
|
|
|
|
|
|
pub fn len(&self) -> usize { |
|
|
self.storage.len() |
|
|
} |
|
|
|
|
|
|
|
|
pub fn is_empty(&self) -> bool { |
|
|
self.storage.is_empty() |
|
|
} |
|
|
|
|
|
|
|
|
pub fn clear(&mut self) { |
|
|
self.storage.clear(); |
|
|
let _ = self.index.rebuild(); |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pub fn near(&self, query: &Point, k: usize) -> NearResult<Vec<SearchResult>> { |
|
|
|
|
|
let query = if self.config.normalize_on_insert { |
|
|
query.normalize() |
|
|
} else { |
|
|
query.clone() |
|
|
}; |
|
|
|
|
|
self.index.near(&query, k) |
|
|
} |
|
|
|
|
|
|
|
|
pub fn within(&self, query: &Point, threshold: f32) -> NearResult<Vec<SearchResult>> { |
|
|
let query = if self.config.normalize_on_insert { |
|
|
query.normalize() |
|
|
} else { |
|
|
query.clone() |
|
|
}; |
|
|
|
|
|
self.index.within(&query, threshold) |
|
|
} |
|
|
|
|
|
|
|
|
pub fn near_with_data(&self, query: &Point, k: usize) -> NearResult<Vec<(&PlacedPoint, f32)>> { |
|
|
let results = self.near(query, k)?; |
|
|
|
|
|
Ok(results |
|
|
.into_iter() |
|
|
.filter_map(|r| self.storage.get(r.id).map(|p| (p, r.score))) |
|
|
.collect()) |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pub fn merge(&self, points: &[Point]) -> Point { |
|
|
self.config.merge.merge(points) |
|
|
} |
|
|
|
|
|
|
|
|
pub fn proximity(&self, a: &Point, b: &Point) -> f32 { |
|
|
self.config.proximity.proximity(a, b) |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pub fn size_bytes(&self) -> usize { |
|
|
self.storage.size_bytes() |
|
|
} |
|
|
|
|
|
|
|
|
pub fn index_len(&self) -> usize { |
|
|
self.index.len() |
|
|
} |
|
|
|
|
|
|
|
|
pub fn is_ready(&self) -> bool { |
|
|
self.index.is_ready() |
|
|
} |
|
|
} |
|
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 3-dimensional `Arms` with the default adapters.
    fn create_test_arms() -> Arms {
        Arms::new(ArmsConfig::new(3))
    }

    #[test]
    fn test_arms_place_and_get() {
        let mut arms = create_test_arms();

        let point = Point::new(vec![1.0, 0.0, 0.0]);
        let blob = Blob::from_str("test data");

        let id = arms.place(point, blob).unwrap();

        let retrieved = arms.get(id).unwrap();
        assert_eq!(retrieved.blob.as_str(), Some("test data"));
    }

    #[test]
    fn test_arms_near() {
        let mut arms = create_test_arms();

        // One point along each axis.
        arms.place(Point::new(vec![1.0, 0.0, 0.0]), Blob::from_str("x")).unwrap();
        arms.place(Point::new(vec![0.0, 1.0, 0.0]), Blob::from_str("y")).unwrap();
        arms.place(Point::new(vec![0.0, 0.0, 1.0]), Blob::from_str("z")).unwrap();

        // A query on the x axis should rank the x point strictly first.
        let query = Point::new(vec![1.0, 0.0, 0.0]);
        let results = arms.near(&query, 2).unwrap();

        assert_eq!(results.len(), 2);
        assert!(results[0].score > results[1].score);
    }

    #[test]
    fn test_arms_near_with_data() {
        let mut arms = create_test_arms();

        arms.place(Point::new(vec![1.0, 0.0, 0.0]), Blob::from_str("x")).unwrap();
        arms.place(Point::new(vec![0.0, 1.0, 0.0]), Blob::from_str("y")).unwrap();

        let query = Point::new(vec![1.0, 0.0, 0.0]);
        let results = arms.near_with_data(&query, 1).unwrap();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0.blob.as_str(), Some("x"));
    }

    #[test]
    fn test_arms_remove() {
        let mut arms = create_test_arms();

        let id = arms.place(Point::new(vec![1.0, 0.0, 0.0]), Blob::empty()).unwrap();

        assert!(arms.contains(id));
        assert_eq!(arms.len(), 1);

        // The removed entry is handed back to the caller.
        assert!(arms.remove(id).is_some());

        assert!(!arms.contains(id));
        assert_eq!(arms.len(), 0);
    }

    #[test]
    fn test_arms_merge() {
        let arms = create_test_arms();

        let points = vec![
            Point::new(vec![1.0, 0.0, 0.0]),
            Point::new(vec![0.0, 1.0, 0.0]),
        ];

        let merged = arms.merge(&points);

        // The default merge is expected to average component-wise.
        assert!((merged.dims()[0] - 0.5).abs() < 0.0001);
        assert!((merged.dims()[1] - 0.5).abs() < 0.0001);
        assert!((merged.dims()[2] - 0.0).abs() < 0.0001);
    }

    #[test]
    fn test_arms_clear() {
        let mut arms = create_test_arms();

        // Start at 1 so no zero vector is placed; normalizing a zero vector is an
        // unrelated edge case this test should not depend on.
        for i in 1..=10 {
            arms.place(Point::new(vec![i as f32, 0.0, 0.0]), Blob::empty()).unwrap();
        }

        assert_eq!(arms.len(), 10);

        arms.clear();

        assert_eq!(arms.len(), 0);
        assert!(arms.is_empty());
    }

    #[test]
    fn test_arms_normalizes_on_insert() {
        let mut arms = create_test_arms();

        // A 3-4-5 vector is clearly not unit length before insertion.
        let point = Point::new(vec![3.0, 4.0, 0.0]);
        let id = arms.place(point, Blob::empty()).unwrap();

        let retrieved = arms.get(id).unwrap();
        assert!(retrieved.point.is_normalized());
    }

    #[test]
    fn test_arms_place_batch() {
        let mut arms = create_test_arms();

        let results = arms.place_batch(vec![
            (Point::new(vec![1.0, 0.0, 0.0]), Blob::from_str("x")),
            (Point::new(vec![0.0, 1.0, 0.0]), Blob::from_str("y")),
        ]);

        // One result per input, in input order, each resolvable through `get`.
        assert_eq!(results.len(), 2);
        assert_eq!(arms.len(), 2);
        for (result, expected) in results.into_iter().zip(["x", "y"]) {
            let id = result.unwrap();
            assert_eq!(arms.get(id).unwrap().blob.as_str(), Some(expected));
        }
    }

    #[test]
    fn test_arms_dimensionality() {
        let arms = create_test_arms();
        assert_eq!(arms.dimensionality(), 3);
    }
}
|
|
|