C++下的通用K-means Algorithm模版

之前写过一篇利用k-means算法来计算图像中主要颜色的文章,K-means聚类算法计算给定图像中主要颜色。于是今天顺便写了个较为通用的C++下的K-means算法模版。

有一个主要的模版,还有一个稍有变化的模版,仅体现在传入的第三个参数上。

主要的参数如下,

Parameter Description
k 聚类的种数
min_diff 收束条件,旧的cluster中心与新的之间变化的距离
data1 std::vector<数据类型>
data2 std::vector<std::pair<数据类型, 该实例的统计个数>>
center 传入std::vector<std::pair<数据类型, 该实例的统计个数>>,返回该组数据的中心值
distance 传入给定数据中的两个元素, 返回它们的距离

返回值都是包含k个该种类型的元素的std::vector

一个简单的例子如下

#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <iterator>
#include <vector>
#include "kmeans.hpp"

using namespace std;

int main(int argc, const char * argv[]) {
    vector<double> data;

    for (int i = 0; i < 10; i++) {
        data.emplace_back(233 + random() % 10);
        data.emplace_back(233 - random() % 10);
    }
    for (int i = 0; i < 10; i++) {
        data.emplace_back(2333 + random() % 10);
        data.emplace_back(2333 - random() % 10);
    }
    for (int i = 0; i < 10; i++) {
        data.emplace_back(23333 + random() % 10);
        data.emplace_back(23333 - random() % 10);
    }

    vector<double> result = kmeans<double>(3, 2, data, [&](const std::vector<std::pair<double, uint32_t>>& data) -> double {
        double total = 0, count = 0;

        for_each(data.cbegin(), data.cend(), [&total, &count](const std::pair<double, uint32_t>& point){
            total += point.first;
            count += point.second;
        });
        return total/count;
    }, [&](const double& first, const double& second) -> double {
        double dis = first - second;
        return dis < 0 ? -dis : dis;
    });

    copy(result.cbegin(), result.cend(), ostream_iterator<double>(cout, " "));
    return 0;
}

screenshot

以下是kmeans.hpp的代码,最新的kmeans.hpp可以在我的github上获取到,BlueCocoa/kmeans-template

/**
 *  @brief K-means algorithm
 *
 *  @param k        Number of clusters
 *  @param min_diff Minimum difference
 *  @param data     Data to cluster
 *  @param center   Custom center calcualation
 *  @param distance Custom distance calculation
 */
template <class _Element>
std::vector<_Element> kmeans(ssize_t k, double min_diff, const std::vector<_Element>& data, const std::function<_Element(const std::vector<std::pair<_Element, uint32_t>>& data)>& center, const std::function<double(const _Element& first, const _Element& second)>& distance);

/**
 *  @brief K-means algorithm
 *
 *  @param k        Number of clusters
 *  @param min_diff Minimum difference
 *  @param data     Data (with count) to cluster
 *  @param center   Custom center calcualation
 *  @param distance Custom distance calculation
 */
template <class _Element>
std::vector<_Element> kmeans(ssize_t k, double min_diff, const std::vector<std::pair<_Element, uint32_t>>& data, const std::function<_Element(const std::vector<std::pair<_Element, uint32_t>>& data)>& center, const std::function<double(const _Element& first, const _Element& second)>& distance);

template <class _Element>
std::vector<_Element> kmeans(ssize_t k, double min_diff, const std::vector<_Element>& data, const std::function<_Element(const std::vector<std::pair<_Element, uint32_t>>& data)>& center, const std::function<double(const _Element& first, const _Element& second)>& distance) {
    using ElementCount  = std::pair<_Element, uint32_t>;
    std::vector<ElementCount> elements;
    std::for_each(data.cbegin(), data.cend(), [&](const _Element& element) {
        elements.emplace_back(element, 1);
    });
    return kmeans<_Element>(k, min_diff, elements, center, distance);
}

template <class _Element>
std::vector<_Element> kmeans(ssize_t k, double min_diff, const std::vector<std::pair<_Element, uint32_t>>& data, const std::function<_Element(const std::vector<std::pair<_Element, uint32_t>>& data)>& center, const std::function<double(const _Element& first, const _Element& second)>& distance) {
    using ElementCount  = std::pair<_Element, uint32_t>;
    using Cluster       = std::vector<_Element>;
    using ClusterPoint  = std::map<uint32_t, std::vector<ElementCount>>;

    Cluster clusters;
    uint32_t randmax = static_cast<uint32_t>(data.size());

    for (uint32_t i = 0; i < k; i++) {
        auto iter = data.cbegin();
        for (uint32_t t = 0; t < arc4random() % randmax; t++, iter++);
        clusters.emplace_back(iter->first);
    }

    while (1) {
        ClusterPoint points;

        std::for_each(data.cbegin(), data.cend(), [&](const std::pair<_Element, uint32_t>& point) {
            double smallestDistance = DBL_MAX;
            double dis;
            uint32_t smallestIndex;
            for (uint32_t i = 0; i < k; i++) {
                dis = distance(point.first, clusters[i]);

                if (dis < smallestDistance) {
                    smallestDistance = dis;
                    smallestIndex = i;
                }
            }
            points[smallestIndex].emplace_back(point);
        });

        double diff = 0;
        for (uint32_t i = 0; i < k; i++) {
            _Element oldCenter = clusters[i];
            _Element newCenter = center(points[i]);
            clusters[i] = newCenter;
            double dis = distance(oldCenter, newCenter);
            diff = std::max(diff, dis);
        }

        if (diff < min_diff) {
            break;
        }
    }

    return clusters;
}
声明: 本文为0xBBC原创, 转载注明出处喵~

发表评论

电子邮件地址不会被公开。 必填项已用*标注