`
cyxlgzs
  • 浏览: 90282 次
  • 性别: Icon_minigender_1
社区版块
存档分类
最新评论

算法——K均值聚类算法扩展应用(Java实现)

 
阅读更多

1、前面一篇文章算法——K均值聚类算法(Java实现)简单的实现了一下K均值分类算法,这节我们对于他的应用进行一个扩展应用

2、目标为对对象的分类

3、具体实现如下

1)首先建立一个基类KmeansObject,目的为继承该类的子类都可以应用我们的k均值算法进行分类,代码如下

package org.cyxl.util.algorithm;

/**
 * 所有使用k均值分类算法的对象都必须继承自该对象
 * @author cyxl
 * @version 1.0 2012-05-24
 * @since 1.0
 *
 */
public class KmeansObject {
	public float compare;		//比较因子
}

2)算法实现,代码如下

package org.cyxl.util.algorithm;

import java.util.ArrayList;
import java.util.Random;

/**
 * K均值聚类算法
 */
public class CommonKmeans {
	private int k;// 分成多少簇
	private int m;// 迭代次数
	private int dataSetLength;// 数据集元素个数,即数据集的长度
	private ArrayList<KmeansObject> dataSet;// 数据集链表
	private ArrayList<KmeansObject> center;// 中心链表
	private ArrayList<ArrayList<KmeansObject>> cluster; // 簇
	private ArrayList<Float> jc;// 误差平方和,k越接近dataSetLength,误差越小
	private Random random;

	/**
	 * 设置需分组的原始数据集
	 * 
	 * @param dataSet
	 */

	public void setDataSet(ArrayList<KmeansObject> dataSet) {
		this.dataSet = dataSet;
	}

	/**
	 * 获取结果分组
	 * 
	 * @return 结果集
	 */

	public ArrayList<ArrayList<KmeansObject>> getCluster() {
		return cluster;
	}

	/**
	 * 构造函数,传入需要分成的簇数量
	 * 
	 * @param k
	 *            簇数量,若k<=0时,设置为1,若k大于数据源的长度时,置为数据源的长度
	 */
	public CommonKmeans(int k) {
		if (k <= 0) {
			k = 1;
		}
		this.k = k;
	}

	/**
	 * 初始化
	 */
	private void init() {
		m = 0;
		random = new Random();
		if (dataSet == null || dataSet.size() == 0) {
			initDataSet();
		}
		dataSetLength = dataSet.size();
		if (k > dataSetLength) {
			k = dataSetLength;
		}
		center = initCenters();
		cluster = initCluster();
		jc = new ArrayList<Float>();
	}

	/**
	 * 如果调用者未初始化数据集,则采用内部测试数据集
	 */
	private void initDataSet() {
		dataSet = new ArrayList<KmeansObject>();
		
		for(int i=0;i<10;i++)
		{
			int temp = random.nextInt(100);
			KmeansObject ko=new KmeansObject();
			ko.compare=temp;
			dataSet.add(ko);
		}
	}

	/**
	 * 初始化中心数据链表,分成多少簇就有多少个中心点
	 * 
	 * @return 中心点集
	 */
	private ArrayList<KmeansObject> initCenters() {
		ArrayList<KmeansObject> center = new ArrayList<KmeansObject>();
		int[] randoms = new int[k];
		boolean flag;
		int temp = random.nextInt(dataSetLength);
		randoms[0] = temp;
		for (int i = 1; i < k; i++) {
			flag = true;
			while (flag) {
				temp = random.nextInt(dataSetLength);
				int j = 0;
				// 不清楚for循环导致j无法加1
				// for(j=0;j<i;++j)
				// {
				// if(temp==randoms[j]);
				// {
				// break;
				// }
				// }
				while (j < i) {
					if (temp == randoms[j]) {
						break;
					}
					j++;
				}
				if (j == i) {
					flag = false;
				}
			}
			randoms[i] = temp;
		}

		for (int i = 0; i < k; i++) {
			center.add(dataSet.get(randoms[i]));// 生成初始化中心链表
		}
		return center;
	}

	/**
	 * 初始化簇集合
	 * 
	 * @return 一个分为k簇的空数据的簇集合
	 */
	private ArrayList<ArrayList<KmeansObject>> initCluster() {
		ArrayList<ArrayList<KmeansObject>> cluster = new ArrayList<ArrayList<KmeansObject>>();
		for (int i = 0; i < k; i++) {
			cluster.add(new ArrayList<KmeansObject>());
		}

		return cluster;
	}

	/**
	 * 计算两个点之间的距离
	 * 
	 * @param element
	 *            点1
	 * @param center
	 *            点2
	 * @return 距离
	 */
	private float distance(KmeansObject element, KmeansObject center) {
		float distance = 0.0f;

		distance=Math.abs(element.compare-center.compare);
		
		return distance;
	}

	/**
	 * 获取距离集合中最小距离的位置
	 * 
	 * @param distance
	 *            距离数组
	 * @return 最小距离在距离数组中的位置
	 */
	private int minDistance(float[] distance) {
		float minDistance = distance[0];
		int minLocation = 0;
		for (int i = 1; i < distance.length; i++) {
			if (distance[i] < minDistance) {
				minDistance = distance[i];
				minLocation = i;
			} else if (distance[i] == minDistance) // 如果相等,随机返回一个位置
			{
				if (random.nextInt(10) < 5) {
					minLocation = i;
				}
			}
		}

		return minLocation;
	}

	/**
	 * 核心,将当前元素放到最小距离中心相关的簇中
	 */
	private void clusterSet() {
		float[] distance = new float[k];
		for (int i = 0; i < dataSetLength; i++) {
			for (int j = 0; j < k; j++) {
				distance[j] = distance(dataSet.get(i), center.get(j));

			}
			int minLocation = minDistance(distance);

			cluster.get(minLocation).add(dataSet.get(i));// 核心,将当前元素放到最小距离中心相关的簇中

		}
	}

	/**
	 * 求两点误差平方的方法
	 * 
	 * @param element
	 *            点1
	 * @param center
	 *            点2
	 * @return 误差平方
	 */
	private float errorSquare(KmeansObject element, KmeansObject center) {
		float x = Math.abs(element.compare-center.compare);
		
		float errSquare = x * x;

		return errSquare;
	}

	/**
	 * 计算误差平方和准则函数方法
	 */
	private void countRule() {
		float jcF = 0;
		for (int i = 0; i < cluster.size(); i++) {
			for (int j = 0; j < cluster.get(i).size(); j++) {
				jcF += errorSquare(cluster.get(i).get(j), center.get(i));

			}
		}
		jc.add(jcF);
	}

	/**
	 * 设置新的簇中心方法
	 */
	private void setNewCenter() {
		for (int i = 0; i < k; i++) {
			int n = cluster.get(i).size();
			if (n != 0) {
				KmeansObject newCenter = new KmeansObject();
				for (int j = 0; j < n; j++) {
					newCenter.compare += cluster.get(i).get(j).compare;
				}
				// 设置一个平均值
				newCenter.compare=newCenter.compare/n;
				
				center.set(i, newCenter);
			}
		}
	}

	/**
	 * 打印数据,测试用
	 * 
	 * @param dataArray
	 *            数据集
	 * @param dataArrayName
	 *            数据集名称
	 */
	public void printDataArray(ArrayList<KmeansObject> dataArray,
			String dataArrayName) {
		for (int i = 0; i < dataArray.size(); i++) {
			System.out.println("print:" + dataArrayName + "[" + i + "]={"
					+ dataArray.get(i) + "}");
		}
		System.out.println("===================================");
	}

	/**
	 * Kmeans算法核心过程方法
	 */
	private void kmeans() {
		init();

		// 循环分组,直到误差不变为止
		while (true) {
			clusterSet();

			countRule();
			
			// 误差不变了,分组完成
			if (m != 0) {
				if (jc.get(m) - jc.get(m - 1) == 0) {
					break;
				}
			}

			setNewCenter();
			m++;
			cluster.clear();
			cluster = initCluster();
		}
		
	}

	/**
	 * 执行算法
	 */
	public void execute() {
		long startTime = System.currentTimeMillis();
		System.out.println("kmeans begins");
		kmeans();
		long endTime = System.currentTimeMillis();
		System.out.println("kmeans running time=" + (endTime - startTime)
				+ "ms");
		System.out.println("kmeans ends");
		System.out.println();
	}
	
	
}

3)测试算法,首先建立一个Person类,目标在于对人进行分类

package org.cyxl.util.algorithm;

public class Person extends KmeansObject {
	String name="";
	int age=0;
	float qz=1;		//权重
	
	public Person(){}
	
	public Person(String name,int age,float qz)
	{
		this.name=name;
		this.age=age;
		this.qz=qz;
	}
	
	public String getName() {
		return name;
	}
	public void setName(String name) {
		this.name = name;
	}
	public int getAge() {
		return age;
	}
	public void setAge(int age) {
		this.age = age;
	}
	
	public float getQz() {
		return qz;
	}

	public void setQz(float qz) {
		this.qz = qz;
	}

	public String toString()
	{
		return "name:"+this.name+";age:"+this.age+";qz:"+this.qz+";compare:"+super.compare;
	}
}

4)客户端测试代码

                CommonKmeans k=new CommonKmeans(5);
		ArrayList<KmeansObject> list=new ArrayList<KmeansObject>();
		
		for(int i=0;i<10;i++)
		{
			float qz=(float)(new Random().nextInt(10))/10;
			Person p=new Person("name"+i,i,qz);
			p.compare=new Random().nextInt(100)*p.getQz();
			list.add(p);
		}
		k.setDataSet(list);
		k.printDataArray(k.dataSet, "before");
		k.execute();
		ArrayList<ArrayList<KmeansObject>> cluster=k.getCluster();
		//查看结果
		for(int i=0;i<cluster.size();i++)
		{
			k.printDataArray(cluster.get(i), "cluster["+i+"]");
		}

5)输出结果

print:before[0]={name:name0;age:0;qz:0.0;compare:0.0}
print:before[1]={name:name1;age:1;qz:0.9;compare:48.6}
print:before[2]={name:name2;age:2;qz:0.9;compare:57.6}
print:before[3]={name:name3;age:3;qz:0.4;compare:28.4}
print:before[4]={name:name4;age:4;qz:0.0;compare:0.0}
print:before[5]={name:name5;age:5;qz:0.4;compare:33.600002}
print:before[6]={name:name6;age:6;qz:0.5;compare:2.0}
print:before[7]={name:name7;age:7;qz:0.2;compare:14.6}
print:before[8]={name:name8;age:8;qz:0.6;compare:5.4}
print:before[9]={name:name9;age:9;qz:0.9;compare:52.199997}
===================================
kmeans begins
kmeans running time=0ms
kmeans ends

print:cluster[0][0]={name:name3;age:3;qz:0.4;compare:28.4}
print:cluster[0][1]={name:name5;age:5;qz:0.4;compare:33.600002}
===================================
print:cluster[1][0]={name:name7;age:7;qz:0.2;compare:14.6}
===================================
print:cluster[2][0]={name:name2;age:2;qz:0.9;compare:57.6}
===================================
print:cluster[3][0]={name:name1;age:1;qz:0.9;compare:48.6}
print:cluster[3][1]={name:name9;age:9;qz:0.9;compare:52.199997}
===================================
print:cluster[4][0]={name:name0;age:0;qz:0.0;compare:0.0}
print:cluster[4][1]={name:name4;age:4;qz:0.0;compare:0.0}
print:cluster[4][2]={name:name6;age:6;qz:0.5;compare:2.0}
print:cluster[4][3]={name:name8;age:8;qz:0.6;compare:5.4}
===================================

4、说明及总结。

1)基类KmeansObject定义了一个compare,我们把它叫做比较因子,分类时只要就是对分类因子进行分类计算的。所以这个分类因子很重要,每个对象的分类因子可以具体的根据业务进行计算设置。比如我们客户端测试代码中的比较因子的计算方法是,首先给每个对象赋予一个权值qz,然后根据权值和年龄的乘积(具体计算方法根据业务定)来对人群进行分类

2)该算法中对于比较因子compare的计算是影响该算法准确性的一个很重要方面,具体表现在距离(distance方法)和误差(errorSquare方法)计算中。想要改善该算法可以从这两个方法中进行修改

3)当然,我对于这个算法的实现和应用都还是很浅。如果有什么不对或者可以改善的地方请不吝赐教

分享到:
评论

相关推荐

Global site tag (gtag.js) - Google Analytics