【商务智能】数据挖掘–分类算法:C4.5算法

一.前言

在上一篇文章中浅析了ID3算法(见【商务智能】数据挖掘–分类算法:ID3算法),也看到了它的一些缺陷。C4.5就是在ID3的基础上建立起来的,它对ID3进行了改进。因此,对于C4.5算法的学习需要先对ID3算法进行了解。

本文将讲解和实现一个数据通用的C4.5算法程序。

二.算法讲解

转载请注明出处:http://jasonhan.me/blog/?p=206

C4.5的改进主要在于两点:1.不仅可以处理离散属性,对于连续属性也可以进行很好的处理;2.由于ID3中的以Gain划分方式从而使得系统偏向于选择分裂规则多的属性进行划分(即选择值域打的属性进行划分),而C4.5以一种更好的方式去选择最佳属性。

下面引入信息增益率(GainRatio)的概念:

1337275667_6580

 

 

1337275689_5903

 

 

可以看到,SplitInformation(分裂信息)实际上就是S关于属性A的各个值的熵。可以把信息增益和信息增益率的关系看作是速度差和加速度的关系。

ID3以最大的Gain来选择最佳属性,而C4.5以最大的GainRatio来选择最佳属性。

C4.5算法描述为:

  • GetDecisionTree(数据集dataSet,属性集attrSet,目标属性targetAttr)
  • if(数据集为空) return null;
  • if(数据集中的数据在目标属性列上的值全是一样的) return 一个与该属性值对应的叶子节点
  • if(属性集为空) 找到数据集中占大多数的目标属性值x  return 一个与x对应的叶子节点
  • for(attrSet中每一个属性A)
    • if(A为离散属性) 计算出A的GainRatio.
    • if(A为连续属性)
      • 把A的所有取值排序,找到每两个值的中点组成n-1个中点集合pointsSet
      • for(pointSet中的每个point)
        • 计算以(<=point)和(>point)来分割属性A情况下的Gain值
      • 找到最大的Gain值对应的point作为A的分割点,并计算GainRatio
  • 找到最大GainRatio对应的那个属性A
  • 根据属性A的分裂规则得到新的数据集newDataSet
  • for(属性A的每个分裂规则(即树枝))
    • 继续进行分类GetDecisionTree(newDataSet,attrSet-A,targetAttr)

重点是上文中对于连续属性的处理方式,C4.5对于连续属性的分割点是系统自动计算的,而不是像ID3那样去随便将连续属性分段。

三.算法实现:

转载请注明出处:http://jasonhan.me/blog/?p=206

下面实现一个对任何数据都通用的一个C4.5算法系统。

先实现一个属性类(Attribute),可以看到,我们在ID3算法中对于属性的保存只是一个String,而在C4.5中由于我们要区分开连续属性和离散属性,所以对属性类增加一项:isDiscrete,用于保存属性是连续的还是离散的。

/**
 * 属性类
 * @author Jason Han
 *
 */
public class Attribute
{
	public String attrName;//属性名称
	public boolean isDiscrete;//该属性是否是离散型属性
	public Attribute(String attrName,boolean isDiscreate)
	{
		this.attrName=attrName;
		this.isDiscrete=isDiscreate;
	}
}

新增了属性类,那么数据集类(DataSet)的类成员结构也要做相应的一点点改变:

public ArrayList<Attribute> attrSet;//属性集
public ArrayList<ArrayList<String>> dataRows;//数据表
public String targetAttribute;//目标属性名

同时我们需要在ID3基础上新增一个分裂规则类(Rule),即用于保存节点的分裂规则,也就是树里面的树枝。因为在ID3中的属性分裂规则都是ArrayList<String>,而C4.5中由于要处理离散属性,所以分裂规则就较之前的复杂一点。

import java.util.ArrayList;

/**
 * 节点分裂规则类,即树枝
 * @author Jason Han
 *
 */
public class Rule
{
	public ArrayList<String> str_rules;//离散属性的分裂规则
	public double d_rule;//连续属性的分裂规则

	public Rule(ArrayList<String> str_rules)
	{
		this.str_rules=str_rules;
	}

	public Rule(double d_rule)
	{
		this.d_rule=d_rule;
		this.str_rules=null;
	}

	/**
	 * 值的比较
	 * @param obj 传入值
	 * @return 该值处于节点的第几条支路上
	 */
	public int Compare(Object obj)
	{
		if(this.str_rules==null)
		{
			if(Double.parseDouble(obj.toString())<=this.d_rule)
				return 0;
			else
				return 1;
		}
		else
		{
			for(int i=0;i<this.str_rules.size();i++)
			{
				if(obj.toString().equals(this.str_rules.get(i)))
					return i;
			}
			return -1;
		}
	}

	/**
	 * 返回指定index上的rule的格式化表现,便于计算
	 */
	public String GetFormatRuleString(int index)
	{
		if(this.str_rules==null)
			return index==0?("<="+this.d_rule):(">"+this.d_rule);
		else
			return this.str_rules.get(index);
	}
}

树节点类(Node)的结构也需要做一点点变化:

public Attribute attr;//属性
public Rule rules;//属性分裂规则,即树枝
public ArrayList<Node> children;//子节点
public String targetValue;//目标属性值,只有叶子节点有

下面是C4.5类(C45),由于代码较多,只摘取主要的函数,对于缺少的函数代码,在之前ID3算法的讲解文章中可以找到。

/**
	 * 递归获取决策树
	 * @param dataSet 数据集
	 * @return 决策树根节点
	 */
	private Node GetDecisionTreeDFS(DataSet dataSet)
	{
		//如果数据集为空,返回空节点
		if(dataSet.dataRows.size()==0)
			return null;

		//如果所有数据都属于同一类,则返回一个叶节点
		if(TargetAttrIsAllSame(dataSet))
			return new Node(new Attribute(dataSet.targetAttribute,true),dataSet.dataRows.get(0).get(origionalDataSet.attrSet.size()));

		//如果数据集属性集为空,返回一个对应于大多数种类的叶节点
		if(dataSet.attrSet.size()<=0)
			return new Node(new Attribute(dataSet.targetAttribute,false),this.GetMajorTargetValue(dataSet));

		//获取最佳属性
		Attribute maxGainAttr=null;
		Double maxGainRatio=null;
		Rule rule=null;
		for(Attribute attrName : dataSet.attrSet)
		{
			Rule tempRule=this.GetAttrRules(dataSet, attrName);
			double gainRatio=GetGainRatio(dataSet, attrName,tempRule);
			if(maxGainRatio==null||gainRatio>maxGainRatio)
			{
				maxGainRatio=gainRatio;
				maxGainAttr=attrName;
				rule=tempRule;
			}
		}

		Node node=new Node(maxGainAttr, rule);//生成一个属性节点
		//节点分裂,对属性的每个取值进行再次划分
		for(int i=0;i<(node.attr.isDiscrete?node.rules.str_rules.size():2);i++)
		{
			//获取新的数据集
			ArrayList<Attribute> newAttrSet=new ArrayList<Attribute>();
			for(Attribute attr : dataSet.attrSet)
			{
				if(!attr.attrName.equals(maxGainAttr.attrName))
					newAttrSet.add(attr);
			}

			//获取新的数据集
			DataSet newDataSet=FindSpecificDT(dataSet, maxGainAttr, node.rules.GetFormatRuleString(i));
			newDataSet.attrSet=newAttrSet;

			//递归划分
			node.children.add(GetDecisionTreeDFS(newDataSet));
		}

		return node;
	}

	/**
	 * 计算指定属性的值域,即属性节点的分裂规则
	 * @param dataSet 数据集
	 * @param attrName 属性名
	 * @return 节点规则,即树枝
	 */
	private Rule GetAttrRules(DataSet dataSet,Attribute attr)
	{
		if(attr.isDiscrete)//对离散属性的处理,和之前的ID3算法一致
		{
			ArrayList<String> result=new ArrayList<>();
			int columnIndex=origionalDataSet.attrSet.indexOf(attr);
			for(ArrayList<String> row :dataSet.dataRows)
			{
				String value=row.get(columnIndex);
				if(!result.contains(value))
				{
					result.add(value);
				}
			}
			Rule rule=new Rule(result);
			return rule;
		}
		else //对连续属性的划分
		{
			//获取该列中的各个属性值
			ArrayList<Double> attrValues=new ArrayList<Double>();
			int columnIndex=C45.origionalDataSet.attrSet.indexOf(attr);
			for(ArrayList<String> row : dataSet.dataRows)
			{
				double value=Double.parseDouble(row.get(columnIndex));
				attrValues.add(value);
			}

			Collections.sort(attrValues);//排序

			//计算各个中点
			ArrayList<Double> midValue=new ArrayList<Double>();
			Double preValue=null;
			for(int i=0;i<attrValues.size();i++)
			{
				double tempValue=attrValues.get(i);
				if(preValue!=null)
				{
					midValue.add((preValue+tempValue)/2);
				}
				preValue=tempValue;
			}

			//找出gain最大的分割点
			Double maxGain=null;
			double splitPoint=0;
			for(int i=0;i<midValue.size();i++)
			{
				Rule rule=new Rule(midValue.get(i));
				double gain=this.GetGain(dataSet, attr, rule);
				if(maxGain==null||gain>maxGain)
				{
					splitPoint=midValue.get(i);
					maxGain=gain;
				}
			}
			return new Rule(splitPoint);
		}
	}

	/**
	 * 获取熵
	 * @param dataSet 数据集
	 * @param attrName 属性名
	 * @param rules 属性的分裂规则
	 * @return 熵
	 */
	private double GetEntropy(DataSet dataSet,Attribute attrName,Rule rule)
	{
		if(attrName==null)//计算dataSet的熵
		{
			Map<String,Integer> map=this.GetEachTargetValueCount(dataSet);
			return CalculateEntropy(map);
		}
		else//计算attrName对应的dataSet的熵
		{
			double result=0.0;
			int rulesCount=attrName.isDiscrete?rule.str_rules.size():2;
			for(int i=0;i<rulesCount;i++)//对于attrName的每个分裂规则进行计算
			{
				Map<String,Integer> map=GetEachTargetValueCount(dataSet, attrName, rule.GetFormatRuleString(i));
				double entropy=this.CalculateEntropy(map);
				double sum=0.0;
				for(Entry<String,Integer> entry : map.entrySet())
				{
					sum+=entry.getValue();
				}
				double dtSize=dataSet.dataRows.size();

				result+=sum/dtSize*entropy;
			}

			return result;
		}
	}

	/**
	 * 获取Gain
	 * @param dataSet 数据集
	 * @param attrName 属性
	 * @param rule 属性的分裂规则,即树枝
	 * @return
	 */
	private double GetGain(DataSet dataSet,Attribute attr,Rule rule)
	{
		return GetEntropy(dataSet, null,null) - GetEntropy(dataSet, attr,rule);
	}

	/**
	 * 获取GainRatio
	 * @param dataSet 数据集
	 * @param attr 属性
	 * @param rule 属性的分裂规则,即树枝
	 * @return
	 */
	private double GetGainRatio(DataSet dataSet,Attribute attr,Rule rule)
	{
		return GetGain(dataSet,attr,rule)/GetSplitInfo(dataSet, attr, rule);
	}

	/**
	 * 获取SplitInfo
	 * @param dataSet 数据集
	 * @param attr 属性
	 * @param rule 属性的分裂规则,即树枝
	 * @return
	 */
	private double GetSplitInfo(DataSet dataSet,Attribute attr,Rule rule)
	{
		double[] counts=this.FindSpecificDTCount(dataSet, attr, rule);
		double sum=dataSet.dataRows.size();
		double result=0.0;
		for(int i=0;i<counts.length;i++)
		{
			if(counts[i]!=0)
				result+=-(counts[i]/sum)*Math.log(counts[i]/sum)/Math.log(2.0);
		}
		return result;
	}

	/**
	 * 计算熵
	 * @param map 目标属性的值->该值的个数
	 * @return 熵
	 */
	private double CalculateEntropy(Map<String,Integer> map)
	{
		double sum=0.0;
		for(Entry<String, Integer> entry : map.entrySet())
		{
			sum+=entry.getValue();
		}

		double result=0.0;
		for(Map.Entry<String, Integer> entry : map.entrySet())
		{
			int value=entry.getValue();
			if(value==0)
				continue;
			result+=-((double)value/sum)*Math.log((double)value/sum)/Math.log(2.0);
		}
		return result;
	}

	/**
	 * 获取各个目标属性值对应的数量
	 * @return Map: 目标属性值-->数量
	 */
	private Map<String,Integer> GetEachTargetValueCount(DataSet dataSet)
	{
		Map<String,Integer> result=new Hashtable<String, Integer>();
		for(ArrayList<String> data : dataSet.dataRows)
		{
			String targetValue=data.get(origionalDataSet.attrSet.size());
			if(result.containsKey(targetValue))
				result.put(targetValue, result.get(targetValue)+1);
			else
				result.put(targetValue, 1);
		}
		return result;
	}

	/**
	 * 获取特定的属性的特定值所对应的dataset中目标属性的数量
	 * @param dataSet 数据集
	 * @param attrName 属性
	 * @param value 属性值
	 * @return Map: 目标属性值-->数量
	 */
	private Map<String,Integer> GetEachTargetValueCount(DataSet dataSet,Attribute attr,String value)
	{
		Double splitPoint=null;
		boolean isLessThan=false;
		if(!attr.isDiscrete)//对连续属性的处理
		{
			isLessThan=value.startsWith("<");
			if(isLessThan)
			{
				splitPoint=Double.parseDouble(value.substring(2));
			}
			else
			{
				splitPoint=Double.parseDouble(value.substring(1));
			}
		}

		Map<String,Integer> map=new Hashtable<String, Integer>();
		int columnIndex=origionalDataSet.attrSet.indexOf(attr);
		for(int i=0;i<dataSet.dataRows.size();i++)
		{
			if(attr.isDiscrete&&!dataSet.dataRows.get(i).get(columnIndex).equals(value))
				continue;
			else if(!attr.isDiscrete)
			{
				double attrValue=Double.parseDouble(dataSet.dataRows.get(i).get(columnIndex));
				if((isLessThan&&attrValue>splitPoint)||(!isLessThan&&attrValue<=splitPoint))
					continue;
			}

			String targetValue=dataSet.dataRows.get(i).get(origionalDataSet.attrSet.size());
			if(map.containsKey(targetValue))
				map.put(targetValue, map.get(targetValue)+1);
			else
				map.put(targetValue, 1);

		}
		return map;
	}

	/**
	 * 获取:在数据集中,属性attr的各个取值对应的数据条数
	 * @param dataSet
	 * @param attr
	 * @param rule
	 * @return
	 */
	private double[] FindSpecificDTCount(DataSet dataSet,Attribute attr,Rule rule)
	{
		double[] result=new double[rule.str_rules!=null?rule.str_rules.size():2];
		int columnIndex=C45.origionalDataSet.attrSet.indexOf(attr);
		for(ArrayList<String> row : dataSet.dataRows)
		{
			result[rule.Compare(row.get(columnIndex))]++;
		}
		return result;
	}

	/**
	 * 获取特定的属性值对应的数据集
	 * @param dataSet 数据集
	 * @param attr 属性
	 * @param value 属性值
	 * @return 新数据集
	 */
	private DataSet FindSpecificDT(DataSet dataSet,Attribute attr,String value)
	{
		Double splitPoint=null;
		boolean isLessThan=false;
		if(!attr.isDiscrete)//对连续属性的处理
		{
			isLessThan=value.startsWith("<");
			if(isLessThan)
			{
				splitPoint=Double.parseDouble(value.substring(2));
			}
			else
			{
				splitPoint=Double.parseDouble(value.substring(1));
			}
		}

		DataSet result=new DataSet(null,origionalDataSet.targetAttribute);
		int columnIndex=origionalDataSet.attrSet.indexOf(attr);
		for(ArrayList<String> row : dataSet.dataRows)
		{
			if(attr.isDiscrete&&row.get(columnIndex).equals(value))
			{
				result.AddRow(row);
			}
			else if(!attr.isDiscrete)
			{
				double attrValue=Double.parseDouble(row.get(columnIndex));
				if((isLessThan&&attrValue<=splitPoint)||(!isLessThan&&attrValue>splitPoint))
					result.AddRow(row);
			}
		}
		return result;
	}

可以看到代码和ID3算法的出入不大,重点是在每一步计算的时候都要将离散属性和连续属性区别对待,还有就是需要计算GainRatio而不仅仅是Gain。

我们对之前的预测人的身高类型的案例进行测试,测试结果为:

C45testresult

 

 

 

 

对打高尔夫球案例测试结果:

C45testresult2

 

 

 

 

 

 

转载请注明出处:http://jasonhan.me/blog/?p=206

1 条评论

  1. Im no professional, but I feel you just made the best point. You obviously fully understand what youre talking about, and I can truly get behind that. Thanks for staying so upfront and so straightforward.

发表评论

电子邮件地址不会被公开。 必填项已用*标注

您可以使用这些HTML标签和属性: <a href="" title=""> <abbr title=""> <acronym title=""> <b> <blockquote cite=""> <cite> <code> <del datetime=""> <em> <i> <q cite=""> <strike> <strong>