【商务智能】数据挖掘–分类算法:ID3算法

一.前言

ID3和C4.5算法属于分类算法,对于分类和聚类的概念介绍可见之前的一篇文章:【商务智能】浅谈数据挖掘分类和聚类算法

前面讲了贝叶斯分类算法,ID3和C4.5和贝叶斯最大的不同是他们属于决策树分类,在算法过程中需要生成一颗决策树。

再使用之前高尔夫球场的例子,给出数据集:

68ffc7a44afbc8155ae65&690

 

 

 

 

 

 

 

 

 

经过算法可生成一颗决策树如下:

golfDTree

 

 

 

 

 

 

 

 

在测试数据时,使用未知数据对树进行遍历,直到达到一个叶节点。

本文将讲解和实现一个数据通用的ID3算法。

二.ID3算法

转载请注明出处: http://jasonhan.me/blog/?p=196

1.算法讲解:

算法的描述为:

  • GetDecisionTree(数据集,属性集,目标属性)
  • if(数据集为空) return null;
  • if(数据集中的数据在目标属性列上的值全是一样的) return 一个与该属性值对应的叶子节点
  • if(属性集为空) 找到数据集中占大多数的目标属性值x  return 一个与x对应的叶子节点
  • 找到一个【最好的】属性A,生成一个与A对应的节点
  • 对于A的每个取值value,找到满足A的值=value的数据集newDataSet
  • GetDecisionTree(newDataSet,属性集-A,目标属性)

算法的思路就是这样,但是如何找上述的【最好的】属性A呢?下面引入熵的概念:

0_1326017726zJsF

 

 

(图片来自网络,非原创)

其中pi表示样例中各种类型所占的比例,例如对于之前说的判断人身高类型的案例(见之前的贝叶斯分类),如果tall类型有1个,medium类型有2个,short有3个,那么Entropy(S)=-(1/6)log2(1/6)-(2/6)log2(2/6)-(3/6)log2(3/6)。

信息论中对熵的一种解释,熵确定了要编码集合S中任意成员的分类所需要的最少二进制位数。将概念挪到分类算法上,这里的熵可以衡量样例的纯度。
下面引入信息增益的概念:
 0_133094595205ru
(图片来自网络,非原创)
一个属性的信息增益就是由于使用这个属性分割样例而导致的熵降低。说通俗点,Gain(S,A)就是指,如果我们用属性A来分割数据集,那么熵可以减低多少,即样例的纯度能增加多少。我们选取Gain值最大的属性A作为当前节点进行分裂。

2.实现:

下面使用java来实现一个对任何数据都通用的ID3算法系统。(在写的过程中深切感觉到如果使用C#,会大大减少代码量,因为C#的linq技术可以实现对数据集进行各种类似sql语言式的检索)。

数据集结构(DataSet):

public ArrayList<String> attrSet;//属性集
public ArrayList<ArrayList<String>> dataRows;//数据表
public String targetAttribute;//目标属性

树节点类(Node):

import java.util.ArrayList;

/**
 * 树节点类
 * @author Jason Han
 *
 */
public class Node 
{
	public String attrName;//属性名
	public ArrayList<String> rules;//分裂规则,即树枝
	public ArrayList<Node> children;//子节点集合
	public String targetValue;//目标属性值,只有叶子节点有

	public Node(String attrName,ArrayList<String> rules)
	{
		this.attrName=attrName;
		this.rules=rules;
		this.targetValue=null;
		this.children=new ArrayList<Node>();
	}

	public Node(String attrName,String targetValue)
	{
		this.attrName=attrName;
		this.rules=null;
		this.targetValue=targetValue;
		this.children=new ArrayList<Node>();
	}

	public void PrintTree()
	{
		this.PrintTree(this, 0, null);
	}

	/**
	 * 遍历打印树结构
	 * @param root 根节点
	 * @param spaceCount 空格数,便于区分树的层次
	 * @param rule 父节点规则,即树枝
	 */
	private void PrintTree(Node root,int spaceCount,String rule)
	{
		if(root==null)
			return;

		for(int i=0;i<spaceCount;i++)
		{
			System.out.print(" ");
		}

		if(root.targetValue!=null)
			System.out.println((rule!=null?rule+":":"")+root.targetValue+"(leaf)");
		else
			System.out.println((rule!=null?rule+":":"")+root.attrName);

		if(root.children!=null&&root.children.size()>0)
		{
			for(int i=0;i<root.children.size();i++)
			{
				PrintTree(root.children.get(i),spaceCount+2,root.rules.get(i));
			}
		}
	}

	/**
	 * 遍历决策树,对未知数据进行测试
	 * @param datas 未知数据
	 * @return 测试结果
	 */
	public String Test(String... datas)
	{
		if(datas.length!=ID3.origionalDataSet.attrSet.size())
		{
			System.out.print("数据不完整,测试失败");
			return "";
		}

		Node node=this;
		while(node!=null)
		{
			if(node.targetValue!=null)
				return node.targetValue;

			String attrName=node.attrName;
			int columnIndex=ID3.origionalDataSet.attrSet.indexOf(attrName);
			boolean testRight=false;
			for(String rule : node.rules)
			{
				if(rule.equals(datas[columnIndex]))
				{
					node=node.children.get(node.rules.indexOf(rule));
					testRight=true;
					break;
				}
			}
			if(!testRight)
				break;
		}

		return null;
	}
}

ID3类:

import java.util.ArrayList;
import java.util.Hashtable;
import java.util.Map;
import java.util.Map.Entry;

/**
 * ID3类
 * @author Jason Han
 *
 */
public class ID3 
{
	public static DataSet origionalDataSet;//最初的数据集

	public ID3(DataSet dataSet)
	{
		origionalDataSet=dataSet;
	}

	/**
	 * 获取决策树
	 * @return 决策树根节点
	 */
	public Node GetDecisionTree()
	{
		return this.GetDecisionTreeDFS(origionalDataSet);
	}

	/**
	 * 递归获取决策树
	 * @param dataSet 数据集
	 * @return 决策树根节点
	 */
	private Node GetDecisionTreeDFS(DataSet dataSet)
	{
		if(dataSet.dataRows.size()==0)
			return null;

		//如果剩下的数据都是同一类,则返回一个叶子节点
		if(TargetAttrIsAllSame(dataSet))
			return new Node(dataSet.targetAttribute,dataSet.dataRows.get(0).get(origionalDataSet.attrSet.size()));

		//如果属性集已经为空了,则获取数据集中的大多数种类来组成一个叶节点
		if(dataSet.attrSet.size()<=0)
			return new Node(dataSet.targetAttribute,this.GetMajorTargetValue(dataSet));

		//寻找具有最大的Gain值的属性
		String maxGainAttrName=null;
		double maxGain=-1;
		ArrayList<String> rules=new ArrayList<String>();//该属性值的分裂规则,即各个树枝
		for(String attrName : dataSet.attrSet)
		{
			ArrayList<String> tempRules=this.GetAttrRules(dataSet, attrName);
			double gain=GetGain(dataSet, attrName,tempRules);
			if(gain>maxGain)
			{
				maxGain=gain;
				maxGainAttrName=attrName;
				rules.clear();
				rules.addAll(tempRules);
			}
		}

		Node node=new Node(maxGainAttrName, rules);//生成一个新的节点
		for(int i=0;i<node.rules.size();i++)//对于每个树枝再继续分类
		{
			//获取新的属性集
			ArrayList<String> newAttrSet=new ArrayList<String>();
			for(String attr : dataSet.attrSet)
			{
				if(attr!=maxGainAttrName)
					newAttrSet.add(attr);
			}

			//获取新的数据集
			DataSet newDataSet=FindSpecificDT(dataSet, maxGainAttrName, node.rules.get(i));
			newDataSet.attrSet=newAttrSet;

			//继续分类
			node.children.add(GetDecisionTreeDFS(newDataSet));
		}

		return node;
	}

	/**
	 * 计算指定属性的值域,即属性节点的分裂规则
	 * @param dataSet
	 * @param attrName
	 * @return
	 */
	private ArrayList<String> GetAttrRules(DataSet dataSet,String attrName)
	{
		ArrayList<String> result=new ArrayList<>();
		int columnIndex=origionalDataSet.attrSet.indexOf(attrName);
		for(ArrayList<String> row :dataSet.dataRows)
		{
			String value=row.get(columnIndex);
			if(!result.contains(value))
			{
				result.add(value);
			}
		}
		return result;
	}

	/**
	 * 获取熵值
	 * @param dataSet 数据集
	 * @param attrName 属性名
	 * @param rules 属性分裂规则,即属性名
	 * @return 熵
	 */
	private double GetEntropy(DataSet dataSet,String attrName,ArrayList<String> rules)
	{		
		if(attrName==null)
		{
			Map<String,Integer> map=this.GetEachTargetValueCount(dataSet);
			return CalculateEntropy(map);
		}
		else
		{
			double result=0.0;

			for(int i=0;i<rules.size();i++)
			{
				Map<String,Integer> map=GetEachTargetValueCount(dataSet, attrName, rules.get(i));
				double entropy=this.CalculateEntropy(map);
				double sum=0.0;
				for(Entry<String,Integer> entry : map.entrySet())
				{
					sum+=entry.getValue();
				}
				double dtSize=dataSet.dataRows.size();

				result+=sum/dtSize*entropy;
			}
			return result;
		}
	}

	/**
	 * 获取Gain值
	 * @param dataSet 数据集
	 * @param attrName 属性名
	 * @param rules 属性的分裂规则,即属性的值域
	 * @return Gian值
	 */
	private double GetGain(DataSet dataSet,String attrName,ArrayList<String> rules)
	{
		return GetEntropy(dataSet, null,null) - GetEntropy(dataSet, attrName,rules);
	}

	/**
	 * 计算熵
	 * @param map 目标属性的值->个数
	 * @return 熵
	 */
	private double CalculateEntropy(Map<String,Integer> map)
	{
		double sum=0.0;
		for(Entry<String, Integer> entry : map.entrySet())
		{
			sum+=entry.getValue();
		}

		double result=0.0;
		for(Map.Entry<String, Integer> entry : map.entrySet())
		{
			int value=entry.getValue();
			if(value==0)
				continue;
			result+=-((double)value/sum)*Math.log((double)value/sum)/Math.log(2.0);
		}
		return result;
	}

	/**
	 * 获取各个目标属性值对应的数量
	 * @return 哈希表【 目标属性的值->数量】
	 */
	private Map<String,Integer> GetEachTargetValueCount(DataSet dataSet)
	{
		Map<String,Integer> result=new Hashtable<String, Integer>();
		for(ArrayList<String> data : dataSet.dataRows)
		{
			String targetValue=data.get(origionalDataSet.attrSet.size());
			if(result.containsKey(targetValue))
				result.put(targetValue, result.get(targetValue)+1);
			else
				result.put(targetValue, 1);
		}
		return result;
	}

	/**
	 * 获取特定的属性的特定值所对应的dataset中目标属性的数量
	 * @param dataSet 数据集
	 * @param attrName 属性名
	 * @param value 属性值
	 * @return 哈希表【目标属性的值->数量】
	 */
	private Map<String,Integer> GetEachTargetValueCount(DataSet dataSet,String attrName,String value)
	{
		Map<String,Integer> map=new Hashtable<String, Integer>();
		int columnIndex=origionalDataSet.attrSet.indexOf(attrName);
		for(int i=0;i<dataSet.dataRows.size();i++)
		{
			if(dataSet.dataRows.get(i).get(columnIndex).equals(value))
			{
				String targetValue=dataSet.dataRows.get(i).get(origionalDataSet.attrSet.size());
				if(map.containsKey(targetValue))
					map.put(targetValue, map.get(targetValue)+1);
				else
					map.put(targetValue, 1);
			}
		}
		return map;
	}

	/**
	 * 获取属性名为attr且该属性的值为value对应的数据集
	 * @param dataSet 数据集
	 * @param attr 属性名
	 * @param value 属性值
	 * @return 新的数据集
	 */
	private DataSet FindSpecificDT(DataSet dataSet,String attr,String value)
	{
		DataSet result=new DataSet(null,origionalDataSet.targetAttribute);
		int columnIndex=origionalDataSet.attrSet.indexOf(attr);
		for(ArrayList<String> row : dataSet.dataRows)
		{
			if(row.get(columnIndex).equals(value))
			{
				result.AddRow(row);
			}
		}
		return result;
	}

	/**
	 * 判断数据集中的数据是否都属于同一类
	 * @param dataSet 数据集
	 * @return 布尔结果
	 */
	private boolean TargetAttrIsAllSame(DataSet dataSet)
	{
		String tempValue=null;
		for(ArrayList<String> row : dataSet.dataRows)
		{
			String targetValue=row.get(origionalDataSet.attrSet.size());
			if(tempValue==null)
			{
				tempValue=targetValue;
				continue;
			}

			if(!tempValue.equals(targetValue))
				return false;
		}
		return true;
	}

	/**
	 * 获取数据集中的占大多数的目标属性值
	 * @param dataSet 数据集
	 * @return maxCount目标属性值
	 */
	private String GetMajorTargetValue(DataSet dataSet)
	{
		String majorTargetValue=null;
		int maxCount=-1;
		Map<String,Integer> map=this.GetEachTargetValueCount(dataSet);
		for(Entry<String,Integer> entry : map.entrySet())
		{
			if(entry.getValue()>maxCount)
			{
				majorTargetValue=entry.getKey();
			}
		}
		return majorTargetValue;
	}
}

可以看出ID3算法效率不高,需要对数据集进行多次检索。不过作为最基本的决策树算法,为之后的其他决策树算法奠定了基础。

下面对之前贝叶斯分类中的身高分类案例进行测试:

Name Gender Height Class
Kristina F 1.6m Short
Jim M 2m Tall
Maggie F 1.9m Medium
Martha F 1.88m Medium
Stephanie F 1.7m Short
Bob M 1.85m Medium
Kathy F 1.6m Short
Dave M 1.7m Short
Worth M 2.2m Tall
Steven M 2.1m Tall
Debbie F 1.8m Medium
Todd M 1.95m Medium
Kim F 1.9m Medium
Amy F 1.8m Medium
Wynette F 1.75m Medium

结果为:

ID3testresult

 

 

 

 

 

细心的朋友可以发现,测试的结果并不是我们想象的那样,应该是Medium才正确。为什么会这样?问题就出在算法中的有一句:

  • if(属性集为空) 找到数据集中占大多数的目标属性值x  return 一个与x对应的叶子节点

我们之前将身高划分为了6个段,当分到5的时候,剩下了[M,5,tall]和[M,5,medium],我们在取大多数的时候正好取到了tall而没有取到medium,所以导致测试结果错误。这是算法本身的缺陷。

同时,你会发现当你对身高进行不同的划分时(比如划分为3个段),那么结果又不一样了。这似乎不科学。这也是ID3算法的一大缺陷,就是不能处理连续性属性。这里的身高是一个连续属性(1.6~2.2),所以人为的划分当然会影响结果。

为了解决这一缺陷,后来诞生了C4.5算法。下一篇文章将介绍C4.5算法。

转载请注明出处: http://jasonhan.me/blog/?p=196

发表评论

电子邮件地址不会被公开。 必填项已用*标注

您可以使用这些HTML标签和属性: <a href="" title=""> <abbr title=""> <acronym title=""> <b> <blockquote cite=""> <cite> <code> <del datetime=""> <em> <i> <q cite=""> <strike> <strong>