【商务智能】数据挖掘–分类算法:CART算法

一.前言

前面两篇文章讲了ID3(见【商务智能】数据挖掘–分类算法:ID3)和C4.5算法(见【商务智能】数据挖掘–分类算法:C4.5),CART算法和他们一样,都属于决策树分类算法。如果看过了C4.5算法,那对于CART算法也就很容易理解了。

本文将讲解和实现一个数据通用的CART算法程序。

二.算法描述

转载请注明出处:http://jasonhan.me/blog/?p=208

CART算法主要分为两个步骤:

  1. 建立决策树
  2. 剪枝

建立决策树:

CART算法的思想和C4.5很类似。他们的不同点主要有:1.CART算法在选择最佳属性时使用的是GINI指数而不是GainRatio(信息增益率);2.CART建立的决策树是二叉树而不是像C4.5那样的多叉树。

下面引入GINI指数:

giniT

 

 

GiniT_

 

 

上图中第一个公式中:T表示某个属性,j表示数据集中目标属性值的种类数,pj表示第j类目标属性值所占比例。

上图中第二个公式中:Ni表示数据集按属性T的属性值Ti分割后的个数,N为分割前总个数,后面的gini(Ti)表示数据集按照Ti分割后的gini值。
通过一个例子来看,假设现在有数据集如下:
Gender     Class
M               Short
M               Tall
M               Medium
F               Short
F               Short
F               Medium
其中Class为目标属性,有[Short,Medium,Tall]三种取值。
现在计算Gini(Gender): Gender=M的情况下有3条数据,对应的[Short,Medium,Tall]的个数为[1,1,1],而Gender=非M得情况下有3条数据,对应的[Short,Medium,Tall]的个数为[2,1,0],那么:
Gini(Gender)=(3/6)*(gini(Gender(M)))+(3/6)*(gini(Gender(非M)))
                     =(3/6)*(1-(1/3)^2-(1/3)^2-(1/3)^2) + (3/6)*(1-(2/3)^2-(1/3)^2-(0/3)^2).
可以把Gini指数理解为数据的混乱程度,Gini指数越大,数据越混乱。我们在选择最佳属性时总是选择Gini指数最小的属性。
下面讲如何进行GINI值的计算(注意CART算法的二分性):
CART在每次选取属性进行GINI计算时,都是把属性的分裂规则看做两个。
例如:
  • 如果现在要计算离散属性A的GINI值,属性A有三种取值:M、F、H,那么A的分裂规则就有三种:[M,非M]、[F、非F]、[H、非H],这三种情况都能把数据集分为两份,分别计算这三种情况下的GINI值后,选取最小的GINI值作为属性A的GINI值;
  • 如果要计算连续属性B的GINI值,那么和C4.5类似,找到B中的各个取值,排序后计算出n-1个中点points,每个中点point按照<=point和>point都能把数据集分为两份,计算各个中点的GINI值,选取最小的GINI值作为属性B的GINI值。
建立决策树的算法整体描述为:
  • GetDecisionTree(数据集dataSet,属性集attrSet,目标属性名targetValue)
  • if(数据集为空) return null
  • if(数据集都属于同一类) return 相应的叶节点
  • for(attrSet中的每个属性)
    • 计算GINI值
  • 找到最小的GINI值对应的属性A
  • 按照A的分裂规则将dataSet分为两份:dt1,dt2
  • 递归建树:
    • GetDecisionTree(dt1,attrSet,targetValue)
    • GetDecisionTree(dt2,attrSet,targetValue)

剪枝:

在建立决策树之后,当划分过细时,往往决策树过于庞大,此时的决策树产生过拟合问题(over-fitting),我们需要对树进行剪枝。剪枝分为前剪枝和后剪枝。前剪枝:在建树过程中施加一些条件使得树不至于过度分裂,如指定一个阈值x,当树的高度>=x时不再进行分裂。后剪枝:在建完树之后对树进行修剪。

这里讲讲后剪枝技术。后剪枝有好几种方法,如:代价复杂性剪枝、最小误差剪枝、悲观误差剪枝等等。这里我们使用代价复杂性剪枝。

【以下讲解来自http://http://www.cnblogs.com/zhangchaoyang,讲得很通俗易懂,感谢原作者】

对于分类回归树中的每一个非叶子节点计算它的表面误差率增益值α。

pruneFormu2

 

 

pruneFormu1

 

 

|NTt|是子树中包含的叶子节点个数;

R(t)是节点t在被剪掉的情况下的误差代价,;r(t)是节点t的误差率;

p(t)是节点t上的数据占所有数据的比例。

R(Tt)是子树Tt的误差代价,如果该节点不被剪枝。它等于子树Tt上所有叶子节点的误差代价之和。

举个例子,比如有个非叶子节点t4如图所示:

pruneTree

 

 

 

 

 

 

 

 

 

 

已知所有的数据总共有60条,则节点t4的节点误差代价为:

pruneCal1

 

 

子树误差代价为:

pruneCal2

 

 

以t4为根节点的子树上叶子节点有3个,最终:

pruneCal3

 

 

找到α值最小的非叶子节点,令其左右孩子为NULL。当多个非叶子节点的α值同时达到最小时,取|NTt|最大的进行剪枝。

三.算法实现:

转载请注明出处:http://jasonhan.me/blog/?p=208

下面用java实现一个数据通用的CART算法系统。代码结构和C4.5很类似。

DataSet数据集类和Attribute属性类相比于C4.5中的实现并没有改变。

Node类:由于是二叉树,且涉及到剪枝,所以Node类相对于之前的C4.5中的Node类改动较大(由于代码量较大,对于和之前C4.5算法一样的函数内容进行忽略):

import java.util.ArrayList;
import java.util.Map;
import java.util.Map.Entry;

/**
 * 节点类
 * @author Jason Han
 *
 */
public class Node
{
	public Attribute attr;//属性
	public Rule rules;//分裂规则,即树枝
	public Node lchild;//左孩子
	public Node rchild;//右孩子
	public String targetValue;//目标属性值,只有叶子节点有
	public int nodeNum;//节点的层遍历编号
	private Map<String, Integer> map;//哈希表【目标属性值->个数】
	public double alpha;//错误率增益
	public int leavesCount;//叶子节点个数

	public Node(Attribute attr,Rule rules){...}

	public Node(Attribute attr,String targetValue){...}

	/**
	 * 判断是否是叶子
	 * @return
	 */
	public boolean IsLeafNode(){...}

	/**
	 * 设置目标属性个数表
	 * @param map 哈希表【目标属性->个数】
	 */
	public void SetCountsMap(Map<String,Integer> map)
	{
		this.map=map;
	}

	/**
	 * 打印树
	 */
	public void PrintTree()
	{
		this.PrintTree(this, 0, null,false);
	}

	/**
	 * 打印计算错误率后的树
	 */
	public void PrintNumberedTree()
	{
		this.PrintTree(this, 0, null, true);
	}

	/**
	 * 打印树
	 * @param root 树节点
	 * @param spaceCount 空格数,用于表示树的层次
	 * @param rule 树枝
	 * @param printNum 是否打印节点编号和错误率
	 */
	private void PrintTree(Node root,int spaceCount,String rule,boolean printNum){...}

	/**
	 * 测试未知数据
	 * @param datas 未知数据
	 * @return 结果
	 */
	public String Test(String... datas){...}

	/**
	 * 按层次遍历树,同时将非叶子节点按顺序标号,同时计算alpha值
	 */
	public void LevelOrderAndSetAlpha()
	{
		ArrayList<Node> tempList=new ArrayList<Node>();

		tempList.add(this);
		int cur=0;//当前指针
		int last=1;//每一层最后一个节点的指针
		int nodeNum=1;

		//层遍历
		while(cur<last)
		{
			Node curNode=tempList.get(cur);
			if(curNode.lchild!=null)
			{
				tempList.add(curNode.lchild);
				last++;
			}
			if(curNode.rchild!=null)
			{
				tempList.add(curNode.rchild);
				last++;
			}
			if(!curNode.IsLeafNode())
			{
				curNode.nodeNum=nodeNum++;//层遍历编号
				curNode.CalAlpha();//计算误差率增益
			}
			cur++;
		}
	}

	/**
	 * 计算错误率增益alpha
	 */
	public void CalAlpha()
	{
		if(this.map==null)
			return;

		double leafCount=0;
		double Rt=this.CalRt();
		double RTt=0;
		ArrayList<Node> leafNodes=new ArrayList<Node>();
		this.SearchLeaves(this, leafNodes);
		for(Node node : leafNodes)
		{
			RTt+=node.CalRt();
			leafCount++;
		}

		double result=(Rt-RTt)/(leafCount-1);
		this.alpha=result;
		this.leavesCount=leafNodes.size();
	}

	/**
	 * 计算Rt
	 * @return
	 */
	private double CalRt()
	{
		if(this.map==null)
			return -1;

		double sum=0;
		double maxCount=0;
		for(Entry<String, Integer> entry : this.map.entrySet())
		{
			int value=entry.getValue();
			if(value>maxCount)
				maxCount=value;
			sum+=value;
		}
		double dataRowsCount=CART.origionalDataSet.dataRows.size();
		return (sum-maxCount)/dataRowsCount;
	}

	/**
	 * 找到以node为根节点的所有叶子节点
	 * @param node 根节点
	 * @param al 叶子节点集合
	 */
	private void SearchLeaves(Node node,ArrayList<Node> al)
	{
		if(node==null)
			return;
		if(node.IsLeafNode())
			al.add(node);
		else
		{
			SearchLeaves(node.lchild,al);
			SearchLeaves(node.rchild, al);
		}
	}

	/**
	 * 剪枝
	 * @return 被剪掉的节点的编号
	 */
	public int PruneTree()
	{
		int nodeNum=0;//被剪节点的编号
		Node prunedNode=null;//被剪节点
		double minAlpha=1;//被剪节点的alpha值

		ArrayList<Node> tempList=new ArrayList<Node>();

		tempList.add(this);
		int cur=0;//当前指针
		int last=1;//每一层最后一个节点的指针

		//层遍历
		while(cur<last)
		{
			Node curNode=tempList.get(cur);
			if(curNode.lchild!=null)
			{
				tempList.add(curNode.lchild);
				last++;
			}
			if(curNode.rchild!=null)
			{
				tempList.add(curNode.rchild);
				last++;
			}
			if(!curNode.IsLeafNode())
			{
				if(curNode.alpha<minAlpha||(curNode.alpha==minAlpha&&curNode.leavesCount>prunedNode.leavesCount))
				{
					minAlpha=curNode.alpha;
					nodeNum=curNode.nodeNum;
					prunedNode=curNode;
				}
			}
			cur++;
		}

		prunedNode.lchild=null;
		prunedNode.rchild=null;

		//选取最多的目标属性值作为该节点的targetValue
		int maxCount=-1;
		for(Entry<String,Integer> entry : this.map.entrySet())
		{
			if(entry.getValue()>maxCount)
			{
				prunedNode.targetValue=entry.getKey();
				maxCount=entry.getValue();
			}
		}

		return nodeNum;
	}
}

Rule类相对于之前的C4.5有一些改变(下面两个函数的内容有些变化),因为对离散属性的处理变成了二分的方式:

 	 /**
	 * 值的比较
	 * @param obj 传入值
	 * @return 该值处于节点的第几条支路上
	 */
	public int Compare(Object obj)
	{
		if(this.str_rule==null)
		{
			if(Double.parseDouble(obj.toString())<=this.d_rule)
				return 0;
			else
				return 1;
		}
		else
		{
			return obj.toString().equals(this.str_rule)?0:1;
		}
	}

	/**
	 * 返回指定index上的rule的格式化表现
	 */
	public String GetFormatRuleString(int index)
	{
		if(this.str_rule==null)
			return index==0?("<="+this.d_rule):(">"+this.d_rule);
		else
			return index==0?this.str_rule:"!"+this.str_rule;
	}

下面是CART类,由于代码较多,只贴出重要函数,缺失函数在之前的ID3和C4.5文章中可以找到。

/**
	 * 递归获取决策树
	 * @param dataSet 数据集
	 * @return 决策树根节点
	 */
	private Node GetDecisionTreeDFS(DataSet dataSet)
	{
		//若数据集为空,则返回空节点
		if(dataSet.dataRows.size()==0)
			return null;

		//若数据集中的数据都属于同一类,则返回一个叶子节点
		if(TargetAttrIsAllSame(dataSet))
		{
			Node node = new Node(new Attribute(dataSet.targetAttribute,true),dataSet.dataRows.get(0).get(origionalDataSet.attrSet.size()));
			node.SetCountsMap(this.GetEachTargetValueCount(dataSet));//设置节点上的各个目标属性值的个数
			return node;
		}

		//找到最小Gini值的属性
		Attribute minGiniAttr=null;
		Double minGini=null;
		Rule rule=null;
		for(Attribute attrName : dataSet.attrSet)
		{
			Rule tempRule=this.GetAttrRules(dataSet, attrName);
			double gini=GetGini(dataSet, attrName,tempRule);
			if(minGini==null||gini<minGini)
			{
				minGini=gini;
				minGiniAttr=attrName;
				rule=tempRule;
			}
		}

		//生成新的属性节点
		Node node=new Node(minGiniAttr, rule);
		node.SetCountsMap(this.GetEachTargetValueCount(dataSet));

		//将数据集根据属性节点的树枝进行二分,再递归进行划分
		DataSet[] dataSets=this.SplitDataSet(dataSet, node);
		node.lchild=this.GetDecisionTreeDFS(dataSets[0]);
		node.rchild=this.GetDecisionTreeDFS(dataSets[1]);

		return node;
	}

/**
	 * 计算指定属性的值域,即属性节点的分裂规则
	 * @param dataSet
	 * @param attrName
	 * @return
	 */
	private Rule GetAttrRules(DataSet dataSet,Attribute attr)
	{
		if(attr.isDiscrete)//对于离散属性的处理
		{
			String splitPoint=null;

			//找到该属性的值域strRules
			ArrayList<String> strRules=new ArrayList<String>();
			int columnIndex=origionalDataSet.attrSet.indexOf(attr);
			for(ArrayList<String> row :dataSet.dataRows)
			{
				String value=row.get(columnIndex);
				if(!strRules.contains(value))
				{
					strRules.add(value);
				}
			}

			//对于属性的每个取值进行gini值计算
			//获取gini最小的分割点
			Double minGini=null;
			for(int i=0;i<strRules.size();i++)
			{
				Rule rule=new Rule(strRules.get(i));
				double gini=this.GetGini(dataSet, attr, rule);
				if(minGini==null||gini<minGini)
				{
					splitPoint=strRules.get(i);
					minGini=gini;
				}
			}

			Rule rule=new Rule(splitPoint);
			return rule;
		}
		else//对连续属性的处理
		{
			//找到属性的值域
			ArrayList<Double> attrValues=new ArrayList<Double>();
			int columnIndex=CART.origionalDataSet.attrSet.indexOf(attr);
			for(ArrayList<String> row : dataSet.dataRows)
			{
				double value=Double.parseDouble(row.get(columnIndex));
				attrValues.add(value);
			}

			Collections.sort(attrValues);//排序

			//找到中点集midValue
			ArrayList<Double> midValue=new ArrayList<Double>();
			Double preValue=null;
			for(int i=0;i<attrValues.size();i++)
			{
				double tempValue=attrValues.get(i);
				if(preValue!=null)
				{
					midValue.add((preValue+tempValue)/2);
				}
				preValue=tempValue;
			}

			//找出gini最小的分割点
			Double minGini=null;
			double splitPoint=0;
			for(int i=0;i<midValue.size();i++)
			{
				Rule rule=new Rule(midValue.get(i));
				double gini=this.GetGini(dataSet, attr, rule);
				if(minGini==null||gini<minGini)
				{
					splitPoint=midValue.get(i);
					minGini=gini;
				}
			}
			return new Rule(splitPoint);
		}
	}
/**
	 * 计算gini值
	 * @param map 哈希表(目标属性值->个数)
	 * @return
	 */
	private double CalculateGini(Map<String,Integer> map)
	{
		double sum=0;
		for(Entry<String,Integer> entry : map.entrySet())
		{
			sum+=entry.getValue();
		}

		if(sum==0)
			return 1;

		double result=1;
		for(Entry<String,Integer> entry : map.entrySet())
		{
			double value=entry.getValue();
			result-=Math.pow(value/sum, 2);
		}
		return result;
	}

/**
	 * 将数据集根据节点Node进行二分
	 * @param dataSet 数据集
	 * @param node 节点
	 * @return 二分后的数据集
	 */
	private DataSet[] SplitDataSet(DataSet dataSet,Node node)
	{
		DataSet[] dataSets=new DataSet[]{new DataSet(dataSet.attrSet, dataSet.targetAttribute),
				                         new DataSet(dataSet.attrSet, dataSet.targetAttribute)};

		Rule rule=node.rules;
		int columnIndex=origionalDataSet.attrSet.indexOf(node.attr);
		for(ArrayList<String> row : dataSet.dataRows)
		{
			int index=rule.Compare(row.get(columnIndex));
			if(index>=0)
			{
				dataSets[index].dataRows.add(row);
			}
		}
		return dataSets;
	}

/**
	 * 计算决策树各个节点的错误率
	 */
	public void CalNodesAlpha()
	{
		this.decisionTree.LevelOrderAndSetAlpha();
	}

	/**
	 * 进行剪枝
	 * @return 被剪的节点编号
	 */
	public int PruneTree()
	{
		 return this.decisionTree.PruneTree();
	}

最后对之前的预测身高的案例(见之前的ID3算法文章中)进行测试:

cartResult

 

 

 

 

 

 

 

 

 

 

 

 

转载请注明出处:http://jasonhan.me/blog/?p=208

1 条评论

  1. I’m curious to find out what blog platform you’re working with? I’m experiencing some small security problems with my latest website and I would like to find something more safeguarded. Do you have any recommendations?

发表评论

电子邮件地址不会被公开。 必填项已用*标注

您可以使用这些HTML标签和属性: <a href="" title=""> <abbr title=""> <acronym title=""> <b> <blockquote cite=""> <cite> <code> <del datetime=""> <em> <i> <q cite=""> <strike> <strong>