use pyo3::prelude::*;
use std::cmp::Ordering;
/// Python-visible point holding a list of `f64` coordinates.
#[pyclass]
#[derive(Debug)]
struct Point {
    /// Coordinates exactly as passed to the constructor.
    coords: Vec<f64>,
}

#[pymethods]
impl Point {
    /// Python constructor `Point(coords)`; stores `coords` unchanged, without validation.
    #[new]
    fn new(coords: Vec<f64>) -> Self {
        Self { coords }
    }
}
/// One node of the k-d tree: a stored point and its two child subtrees.
#[derive(Debug)]
struct Node {
    /// Coordinates of the point stored at this node.
    point: Vec<f64>,
    /// Subtree of points that compared less than or equal to `point` on this
    /// node's split axis when they were inserted.
    left: Option<Box<Node>>,
    /// Subtree of points that compared greater than `point` on this node's
    /// split axis when they were inserted.
    right: Option<Box<Node>>,
}

impl Node {
    /// Creates a leaf node (no children) holding `point`.
    fn new(point: Vec<f64>) -> Self {
        Node {
            point,
            left: None,
            right: None,
        }
    }
}

/// Tears the subtree down iteratively instead of through the compiler's
/// recursive drop glue.
///
/// The insertion code in this file does no rebalancing, so inserting
/// already-sorted points yields a chain as deep as the number of points. The
/// default drop recurses once per level and can overflow the stack on such a
/// chain; moving children onto an explicit work list keeps stack use constant.
impl Drop for Node {
    fn drop(&mut self) {
        let mut pending: Vec<Box<Node>> = Vec::new();
        pending.extend(self.left.take());
        pending.extend(self.right.take());
        while let Some(mut node) = pending.pop() {
            // Detach the children first so that dropping `node` at the end of
            // this iteration finds no subtrees and does not recurse.
            pending.extend(node.left.take());
            pending.extend(node.right.take());
        }
    }
}
/// An unbalanced k-d tree of `f64` points, exposed to Python via `#[pyclass]`.
#[pyclass]
#[derive(Debug)]
pub struct KdTree {
    /// Root node, or `None` while the tree holds no points.
    root: Option<Box<Node>>,
    /// Number of coordinate axes; the axis compared at depth `d` is
    /// `d % dimensions`.
    dimensions: usize,
}
#[pymethods]
impl KdTree {
#[new]
pub fn new(dimensions: usize) -> Self {
KdTree {
root: None,
dimensions,
}
}
pub fn insert(&mut self, point: Vec<f64>) {
let root = self.root.take();
self.root = self.insert_recursive(root, point, 0);
}
}
impl KdTree {
    /// Inserts `point` into the subtree rooted at `node` and returns that
    /// subtree's (possibly new) root.
    ///
    /// `depth` is the depth of `node` within the whole tree. The axis compared
    /// at each level is `depth % self.dimensions`. Points that compare less
    /// than or equal to a node's point go left, and greater ones go right.
    ///
    /// Despite its name, the descent is iterative. The tree is never
    /// rebalanced, so sorted input builds a chain as deep as the number of
    /// points, and one stack frame per level could overflow the stack.
    ///
    /// # Panics
    ///
    /// Panics if `self.dimensions` is 0, if `point` or a visited node's point
    /// has too few coordinates for the compared axis, or if a comparison is
    /// unordered (a NaN coordinate). Callers are expected to validate input
    /// first.
    fn insert_recursive(
        &self,
        mut node: Option<Box<Node>>,
        point: Vec<f64>,
        mut depth: usize,
    ) -> Option<Box<Node>> {
        // `slot` always points at the empty-or-occupied child link we are
        // about to inspect; when it is empty, that is where `point` belongs.
        let mut slot = &mut node;
        while let Some(current) = slot {
            let dim = depth % self.dimensions;
            slot = match point[dim].partial_cmp(&current.point[dim]) {
                Some(Ordering::Less) | Some(Ordering::Equal) => &mut current.left,
                Some(Ordering::Greater) => &mut current.right,
                None => panic!("NaNs or unordered comparison encountered!"),
            };
            depth += 1;
        }
        *slot = Some(Box::new(Node::new(point)));
        node
    }
}