从零开始的 Rust 学习笔记(16) —— K-means 模版

重构一下前两天用 Rust 写的 Colorline 中 K-means 聚类算法的部分~因为之前 kmeans 放在了 dominant_color.rs 下,显然 kmeans 这个算法不应该属于 dominant_color;同时,之前的 kmeans 算法只能用在这里,考虑到以后代码复用的话,当然是要写成模版啦╮(╯▽╰)╭

kmeans 独立出来之后,让 kmeans 可以接受任意 impl 了 KmeansComputable trait 的类。其实很久以前用 C++ 也写了一个比较通用的 K-means 模版,但是当时并没有考虑 trait 这样的,而是直接用了两个回调函数(不过写完这个 Rust 版本的之后似乎突然有点思路了)

Rust 这个写起来思路很清晰,首先就是 kmeans 函数应该接受:

  1. 一组待聚类的数据 array
  2. 要求聚类的的类数 k
  3. 收敛条件 min_diff —— k 个类每次迭代后各类中心点移动距离的上界

其中,array 应该是 KmeansComputable 的。

那么 KmeansComputable 这个 trait 的设计的话,第一点显然是要可以给出该类任意两个 instance 之间的距离;第二点则是可以在给出一组该类的后计算其中心点。也就是

pub trait KmeansComputable {
    fn distance(&self, other: &Self) -> f64;
    fn compute_center(cluster: &Vec<Self>) -> Self where Self: Sized;
}

于是 kmeans 函数如下~(高亮的部分则是用到 KmeansComputable trait 里要求实现的函数的地方)

/// K-means
pub fn kmeans<T: Clone + KmeansComputable>(array: &Vec<T>, k: u32, min_diff: f64) -> Vec<T> {
    type Cluster<T> = Vec<T>;
    
    let mut clusters: Cluster<T> = vec![];
    let randmax: usize = array.len();

    // randomly choose k points as initial cluster center
    for _ in 0..k {
        clusters.push(array[rand::random::<usize>() % randmax].clone());
    }
    
    loop {
        // initialize k clusters in this round
        let mut points: HashMap<u32, Vec<T>> = HashMap::new();
        for i in 0..k {
            let value: Vec<T> = vec![];
            points.insert(i, value);
        }
        
        // find the nearest cluster for each pixel
        for element in array {
            let mut nearest_distance = std::f64::MAX;
            let mut nearest_index: u32 = 0;
            
            // calcuate the distance to each cluster
            for i in 0..k {
                // calculate the distance between current pixel from i-th clsuter
                let distance = element.distance(&clusters[i as usize]);
                // if the distance is nearer
                if distance < nearest_distance {
                    // update neatest distance
                    nearest_distance = distance;
                    // update the cluster id to current pixel
                    nearest_index = i;
                }
            }
            
            // assign current item to its nearest cluster
            if let Some(element_entry) = points.get_mut(&nearest_index) {
                (*element_entry).push((*element).clone());
            }
        }
        
        // recalculate center for each cluster
        let mut diff: f64 = 0.0;
        for i in 0..k {
            // store old center
            let old_center = &clusters[i as usize];
            // compute new center
            let new_center = T::compute_center(&points[&i]);
            // the distance that center moved
            let dist = old_center.distance(&new_center);
            // assign new center to cluster[i]
            clusters[i as usize] = new_center;
            // record max moved distance among `k` clusters
            if dist > diff {
                diff = dist;
            }
        }
        
        // if it's stable
        if diff < min_diff {
            break;
        }
    }
    
    clusters
}

那么 Colorline 里原有的 ColorCount 如下

#[derive(Clone)]
pub struct ColorCount {
    color: Color,
    count: u64,
}

impl ColorCount {
    pub fn new(color: &Color, count: u64) -> Self {
        ColorCount {
            color: color.clone(),
            count: count,
        }
    }
    
    pub fn color(&self) -> Color {
        self.color.clone()
    }
}

接下来就是给 ColorCount 实现 KmeansComputable

impl KmeansComputable for ColorCount {
    /// Compute distance from the other instance
    fn distance(&self, other: &Self) -> f64 {
        // euclidean distance, the classic
        let mut distance: f64 = 0.0;
        distance += f64::powf(self.color.b - other.color.b, 2.0);
        distance += f64::powf(self.color.g - other.color.g, 2.0);
        distance += f64::powf(self.color.r - other.color.r, 2.0);
        f64::sqrt(distance)
    }
    
    /// compute the center from the given instance array
    fn compute_center(cluster: &Vec<Self>) -> Self {
        let mut total_count: f64 = 0.0;
        let mut vals = Color::new(0.0, 0.0, 0.0);
        
        cluster.iter().for_each(|ref color_count| {
            let count = color_count.count as f64;
            total_count += count;
        
            vals.b += color_count.color.b * count;
            vals.g += color_count.color.g * count;
            vals.r += color_count.color.r * count;
        });
        
        vals.b /= total_count;
        vals.g /= total_count;
        vals.r /= total_count;
        
        ColorCount::new(&vals, 0)
    }
}

最后,在原来的 dominant_color.rs 里最后的调用就变成了

// transform HashMap to Vec for k-means algorithm
let pixels: Vec<ColorCount> = color_counter.iter().map(|(&color, &count)| ColorCount::new(&color, count)).collect();

// k-means
//   we assuming it's stable if the max moved distance is below 1.0
Ok(kmeans(&pixels, k, 1.0).iter().map(|ref color_count| color_count.color()).collect())
声明: 本文为0xBBC原创, 转载注明出处喵~

Leave a Reply

Your email address will not be published. Required fields are marked *