Picking good watermelons with a decision tree: implementing the algorithm from scratch

1, Theoretical knowledge

  1. Purity
    A branch node is pure when all of its samples belong to the same class. When growing a tree we want each node to be as pure as possible, i.e. as many of its samples as possible should share one class. To measure purity, we use information entropy.

  2. Information entropy
    Suppose the proportion of class-k samples in the current sample set D is $p_k$ ($k = 1, 2, \dots, |\mathcal{Y}|$). The information entropy of D is defined as

         $$\mathrm{Ent}(D) = -\sum_{k=1}^{|\mathcal{Y}|} p_k \log_2 p_k$$

    with the convention that $p \log_2 p = 0$ when $p = 0$.
    

The smaller Ent(D) is, the purer D is. Because $0 \le p_k \le 1$, we have $\log_2 p_k \le 0$ and therefore $\mathrm{Ent}(D) \ge 0$. In the extreme case where all samples in D belong to one class, Ent(D) takes its minimum value 0. When the samples are spread evenly over all $|\mathcal{Y}|$ classes, Ent(D) takes its maximum value $\log_2 |\mathcal{Y}|$.
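A minimal sketch of Ent(D) in plain Python, assuming the class labels are given as a simple list; `entropy` is a hypothetical helper name, not part of the post's code. It checks the two limiting cases described above and a hypothetical 8 vs 9 class split.

```python
import math
from collections import Counter


def entropy(labels):
    """Information entropy Ent(D) of a list of class labels (base-2 logarithm)."""
    n = len(labels)
    # Only classes that actually occur are counted, so every p_k > 0 and the
    # convention p*log2(p) = 0 for p = 0 is handled implicitly.
    return sum(-(c / n) * math.log2(c / n) for c in Counter(labels).values())


print(entropy(['good'] * 5))                           # one class: minimum, 0.0
print(entropy(['a', 'b', 'c', 'd']))                   # 4 equally likely classes: log2(4) = 2.0
print(round(entropy(['good'] * 8 + ['bad'] * 9), 3))   # 0.998
```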

  3. Information gain
    Suppose the discrete attribute a has V possible values $\{a^1, a^2, \dots, a^V\}$. Splitting D on a produces V branch nodes; the v-th branch holds $D^v$, the samples in D whose value of a is $a^v$. Because branches contain different numbers of samples, each branch is weighted by $|D^v|/|D|$, so branches with more samples have more influence. The information gain of splitting D on attribute a is

       $$\mathrm{Gain}(D,a) = \mathrm{Ent}(D) - \sum_{v=1}^{V} \frac{|D^v|}{|D|}\,\mathrm{Ent}(D^v)$$
    

Here Ent(D) is the entropy of D before the split, and $\sum_{v=1}^{V} \frac{|D^v|}{|D|}\,\mathrm{Ent}(D^v)$ is the weighted entropy after it. Their difference ("before" minus "after") is the reduction in entropy produced by the split, i.e. the gain in purity. The larger Gain(D,a), the more splitting on a improves purity, so the attribute with the largest gain is the best split.
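The formula can be sketched as follows, assuming one attribute column and the class labels are given as parallel lists; `info_gain` and the toy data are hypothetical, not from the post. A split that separates the classes perfectly recovers all of Ent(D), while a split that leaves each branch as mixed as D gains nothing.

```python
import math
from collections import Counter, defaultdict


def entropy(labels):
    """Ent(D) of a list of class labels, base-2 logarithm."""
    n = len(labels)
    return sum(-(c / n) * math.log2(c / n) for c in Counter(labels).values())


def info_gain(values, labels):
    """Gain(D, a) = Ent(D) - sum_v |D^v|/|D| * Ent(D^v).

    values[i] is the value of attribute a for sample i, labels[i] its class.
    """
    groups = defaultdict(list)  # a^v -> labels of the samples in D^v
    for v, y in zip(values, labels):
        groups[v].append(y)
    n = len(labels)
    after = sum(len(g) / n * entropy(g) for g in groups.values())
    return entropy(labels) - after


labels = ['good', 'good', 'bad', 'bad']
print(info_gain(['x', 'x', 'y', 'y'], labels))  # perfect split: 1.0
print(info_gain(['x', 'y', 'x', 'y'], labels))  # uninformative split: 0.0
```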

  4. Gain ratio
    Choosing splits by information gain (the information gain criterion) is biased toward attributes with many possible values. The C4.5 algorithm therefore uses the gain ratio instead of information gain to choose the split attribute. The gain ratio is defined as

       $$\mathrm{Gain\_ratio}(D,a) = \frac{\mathrm{Gain}(D,a)}{\mathrm{IV}(a)}$$
    

where

       $$\mathrm{IV}(a) = -\sum_{v=1}^{V} \frac{|D^v|}{|D|} \log_2 \frac{|D^v|}{|D|}$$

IV(a) is called the intrinsic value of attribute a. The more possible values a has (the larger V), the larger IV(a) usually is, so dividing by IV(a) offsets, to some extent, the bias toward attributes with many values.
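IV(a) is just the entropy of the distribution of attribute a's values, so it can be sketched with the same formula applied to a column of values instead of labels. The sketch below (hypothetical helper `intrinsic_value`, hypothetical 16-sample columns) shows that an even split into V branches gives IV(a) = log2 V, which is why IV grows with the number of values.

```python
import math
from collections import Counter


def intrinsic_value(values):
    """IV(a) = -sum_v |D^v|/|D| * log2(|D^v|/|D|), given the values of a over D."""
    n = len(values)
    return sum(-(c / n) * math.log2(c / n) for c in Counter(values).values())


# 16 samples split evenly into V branches: IV(a) = log2(V)
for V in (2, 4, 16):
    values = [i % V for i in range(16)]
    print(V, intrinsic_value(values))  # 2 1.0, then 4 2.0, then 16 4.0
```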

However, the gain ratio criterion is in turn biased toward attributes with few values. C4.5 therefore does not apply it directly: it first keeps the candidate attributes whose information gain is above average, and then picks the one with the highest gain ratio among them.
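The two-step heuristic can be sketched as below. This is a C4.5-style illustration under stated assumptions, not the reference C4.5 implementation: attributes are passed as a dict of hypothetical columns, "above average" is taken as ">= average", and `c45_choose` is a hypothetical name. The unique `id` column ties with `texture` on information gain, but its large IV(a) lowers its gain ratio.

```python
import math
from collections import Counter, defaultdict


def entropy(items):
    """Base-2 entropy of the empirical distribution of the items in a list."""
    n = len(items)
    return sum(-(c / n) * math.log2(c / n) for c in Counter(items).values())


def info_gain(values, labels):
    """Gain(D, a) for one attribute column `values` and the class `labels`."""
    groups = defaultdict(list)  # a^v -> labels of D^v
    for v, y in zip(values, labels):
        groups[v].append(y)
    n = len(labels)
    return entropy(labels) - sum(len(g) / n * entropy(g) for g in groups.values())


def c45_choose(columns, labels):
    """Keep attributes whose gain is at least the average gain, then return
    the one with the highest gain ratio Gain(D, a) / IV(a).

    `columns` maps attribute name -> list of that attribute's values.
    """
    gains = {a: info_gain(vals, labels) for a, vals in columns.items()}
    avg = sum(gains.values()) / len(gains)
    candidates = [a for a in columns if gains[a] >= avg]

    def gain_ratio(a):
        iv = entropy(columns[a])  # IV(a): entropy of the attribute's values
        return gains[a] / iv if iv > 0 else 0.0

    return max(candidates, key=gain_ratio)


# Hypothetical toy data: 'id' is unique per sample, so its gain is maximal
labels = ['good', 'good', 'good', 'bad', 'bad', 'bad']
columns = {
    'id': [1, 2, 3, 4, 5, 6],
    'texture': ['clear', 'clear', 'clear', 'blurry', 'blurry', 'blurry'],
    'touch': ['hard', 'soft', 'hard', 'soft', 'hard', 'soft'],
}
print(c45_choose(columns, labels))  # texture
```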

  5. Gini index
    The CART decision tree algorithm uses the Gini index to choose the split attribute. The Gini value of data set D is defined as

      $$\mathrm{Gini}(D) = \sum_{k=1}^{|\mathcal{Y}|} \sum_{k' \neq k} p_k\, p_{k'} = 1 - \sum_{k=1}^{|\mathcal{Y}|} p_k^2$$
    

Intuitively, Gini(D) is the probability that two samples drawn at random from D (with replacement) have different class labels. The smaller Gini(D), the purer D.
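To make the interpretation concrete, the sketch below computes Gini(D) both from the closed form $1 - \sum_k p_k^2$ and by counting, over all ordered pairs of samples drawn with replacement, how many carry different labels. `gini` and `gini_pairs` are hypothetical helpers and the 8/9 split is made-up data; the two results agree up to floating-point rounding (both equal 144/289 here).

```python
from collections import Counter
from itertools import product


def gini(labels):
    """Gini(D) = 1 - sum_k p_k^2."""
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def gini_pairs(labels):
    """Fraction of ordered pairs (drawn with replacement) whose labels differ."""
    n = len(labels)
    differing = sum(1 for a, b in product(labels, repeat=2) if a != b)
    return differing / (n * n)


labels = ['good'] * 8 + ['bad'] * 9  # hypothetical class split
print(gini(labels), gini_pairs(labels))
```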

The Gini index of attribute a is defined as

 $$\mathrm{Gini\_index}(D,a) = \sum_{v=1}^{V} \frac{|D^v|}{|D|}\,\mathrm{Gini}(D^v)$$

To choose the split, CART picks the attribute whose Gini index after splitting is smallest, i.e. $a_* = \arg\min_{a \in A} \mathrm{Gini\_index}(D,a)$, where A is the set of candidate attributes.
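A minimal sketch of this selection rule, assuming attributes are given as a dict of columns; `gini_index`, `choose_by_gini`, and the toy data are hypothetical. Note that it follows the multiway formula above; the actual CART algorithm builds binary trees (each split is "a = v" versus "a ≠ v"), which this sketch does not do.

```python
from collections import Counter, defaultdict


def gini(labels):
    """Gini(D) = 1 - sum_k p_k^2."""
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def gini_index(values, labels):
    """Gini_index(D, a) = sum_v |D^v|/|D| * Gini(D^v)."""
    groups = defaultdict(list)  # a^v -> labels of D^v
    for v, y in zip(values, labels):
        groups[v].append(y)
    n = len(labels)
    return sum(len(g) / n * gini(g) for g in groups.values())


def choose_by_gini(columns, labels):
    """Return the attribute name with the smallest Gini index after splitting."""
    return min(columns, key=lambda a: gini_index(columns[a], labels))


labels = ['good', 'good', 'bad', 'bad']
columns = {'texture': ['clear', 'clear', 'blurry', 'blurry'],
           'touch': ['hard', 'soft', 'hard', 'soft']}
print(gini_index(columns['texture'], labels))  # 0.0
print(gini_index(columns['touch'], labels))    # 0.5
print(choose_by_gini(columns, labels))         # texture
```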

2, Code implementation

1. Import data and required packages

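import math

import numpy as np
import pandas as pd

# Load the watermelon dataset; the code below treats the last column as the class label
data = pd.read_csv('./Watermelon dataset.csv')
data  # in a Jupyter notebook, this line displays the table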

2. Functions

(1) Compute information entropy

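def calcEntropy(dataSet):
    """Information entropy Ent(D) of a dataset whose last column is the class label."""
    mD = len(dataSet)
    dataLabelList = [x[-1] for x in dataSet]
    dataLabelSet = set(dataLabelList)
    ent = 0
    for label in dataLabelSet:
        mDv = dataLabelList.count(label)
        prop = float(mDv) / mD
        # np.math was removed in NumPy 2.0, so the standard math module is used
        ent = ent - prop * math.log2(prop)

    return ent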

(2) Split the dataset

# index   - column index of the feature to split on
# feature - the value of that feature to keep
# Returns the rows of dataSet whose value in column `index` equals `feature`,
# with that column removed
def splitDataSet(dataSet, index, feature):
    splitedDataSet = []
    for row in dataSet:
        if row[index] == feature:
            # Slicing copies the row, so dataSet itself is not modified
            sliceTmp = row[:index]
            sliceTmp.extend(row[index + 1:])
            splitedDataSet.append(sliceTmp)
    return splitedDataSet

(3) Select the best feature

# Returns the column index of the feature with the largest information gain
def chooseBestFeature(dataSet):
    entD = calcEntropy(dataSet)
    mD = len(dataSet)
    featureNumber = len(dataSet[0]) - 1  # the last column is the label
    maxGain = float('-inf')
    maxIndex = -1
    for i in range(featureNumber):
        # Gain(D, a_i) = Ent(D) - sum_v |D^v|/|D| * Ent(D^v)
        gain = entD
        featureSet = set(x[i] for x in dataSet)
        for feature in featureSet:
            splitedDataSet = splitDataSet(dataSet, i, feature)
            mDv = len(splitedDataSet)
            gain = gain - float(mDv) / mD * calcEntropy(splitedDataSet)
        # On ties, the first feature with the largest gain is kept
        if gain > maxGain:
            maxGain = gain
            maxIndex = i

    return maxIndex

(4) Find the majority label

# Returns the most frequent label in labelList
def mainLabel(labelList):
    labelRec = labelList[0]
    maxLabelCount = -1
    for label in set(labelList):
        count = labelList.count(label)
        if count > maxLabelCount:
            maxLabelCount = count
            labelRec = label
    return labelRec

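(5) Build the tree

# Recursively builds the tree as a nested dict {featureName: {value: subtree or label}}
# featureNames    - names of the features still available (one per remaining column)
# featureNamesSet - for each remaining feature, all of its possible values
# labelListParent - labels of the parent node, used when a branch is empty
def createFullDecisionTree(dataSet, featureNames, featureNamesSet, labelListParent):
    labelList = [x[-1] for x in dataSet]
    if len(dataSet) == 0:  # Empty branch: use the parent's majority label
        return mainLabel(labelListParent)
    elif len(dataSet[0]) == 1:  # No features left to split on
        return mainLabel(labelList)  # Use the majority label of this node
    elif labelList.count(labelList[0]) == len(labelList):  # All samples share one label
        return labelList[0]

    bestFeatureIndex = chooseBestFeature(dataSet)
    bestFeatureName = featureNames.pop(bestFeatureIndex)
    myTree = {bestFeatureName: {}}
    featureList = featureNamesSet.pop(bestFeatureIndex)
    # Branch on every possible value of the feature (taken from the full dataset),
    # so values missing from this subset become empty branches
    for feature in set(featureList):
        featureNamesNext = featureNames[:]
        featureNamesSetNext = featureNamesSet[:]  # shallow copy; inner lists are never modified
        splitedDataSet = splitDataSet(dataSet, bestFeatureIndex, feature)
        myTree[bestFeatureName][feature] = createFullDecisionTree(
            splitedDataSet, featureNamesNext, featureNamesSetNext, labelList)
    return myTree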
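The tree is returned as a nested dict of the form {featureName: {value: subtree or label}}. As an illustration of how that structure can be used, here is a hypothetical `classify` helper (not part of the post) that walks such a dict for one sample; the small tree and its feature values are made up for the example. Because createFullDecisionTree creates a branch for every value seen in the full dataset, a lookup only fails for values that never appeared, in which case this sketch returns None.

```python
def classify(tree, sample):
    """Walk a nested-dict tree {feature: {value: subtree_or_label}} down to a leaf.

    `sample` maps feature name -> value. Returns None if the sample has a
    value for which the tree has no branch.
    """
    while isinstance(tree, dict):
        feature = next(iter(tree))  # each internal node has a single key: its split feature
        branch = tree[feature].get(sample[feature])
        if branch is None:
            return None
        tree = branch
    return tree


# Hypothetical tree in the same shape that createFullDecisionTree returns
tree = {'texture': {'clear': {'Root': {'curled': 'yes', 'stiff': 'no'}},
                    'blurry': 'no'}}
print(classify(tree, {'texture': 'clear', 'Root': 'curled'}))  # yes
print(classify(tree, {'texture': 'blurry', 'Root': 'stiff'}))  # no
```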

(6) Initialization

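# Returns
#   dataSet         - the data as a list of rows (last column = class label)
#   featureNames    - the feature names, in column order
#   featureNamesSet - for each feature column, the list of its distinct values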
def readWatermelonDataSet():
    dataSet = data.values.tolist()
    featureNames =['color and lustre', 'Root', 'Knock', 'texture', 'Umbilicus', 'Tactile sensation']
    #Get featureNamesSet
    featureNamesSet = []
    for i in range(len(dataSet[0]) - 1):
        col = [x[i] for x in dataSet]
        colSet = set(col)
        featureNamesSet.append(list(colSet))
    
    return dataSet, featureNames, featureNamesSet

(7) Plot the tree

import matplotlib
import matplotlib.pyplot as plt

# Use the SimHei font so that Chinese labels render (requires SimHei to be installed)
matplotlib.rcParams['font.sans-serif'] = ['SimHei']
matplotlib.rcParams['font.serif'] = ['SimHei']

# Style of decision (internal) nodes
decisionNode = dict(boxstyle="sawtooth", fc="0.8")

# Style of leaf nodes
leafNode = dict(boxstyle="round4", fc="0.8")

# Arrow style
arrow_args = dict(arrowstyle="<-")


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """
    Draw a node
    :param nodeTxt: Text information describing the node
    :param centerPt: Coordinates of the node text (axes fraction)
    :param parentPt: Coordinates of the parent node, which the arrow points from
    :param nodeType: Node style, either leafNode or decisionNode
    :return:
    """
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)


def getNumLeafs(myTree):
    """
    Gets the number of leaf nodes
    :param myTree:
    :return:
    """
    # Count the total number of leaf nodes
    numLeafs = 0

    # Get the current first key, that is, the root node
    firstStr = list(myTree.keys())[0]

    # Get the content corresponding to the first key
    secondDict = myTree[firstStr]

    # Recursively traversing leaf nodes
    for key in secondDict.keys():
        # If the key corresponds to a dictionary, it is called recursively
        if isinstance(secondDict[key], dict):
            numLeafs += getNumLeafs(secondDict[key])
        # If not, it means that it is a leaf node at this time
        else:
            numLeafs += 1
    return numLeafs


def getTreeDepth(myTree):
    """
    Returns the depth of the tree, i.e. the number of splits on the longest root-to-leaf path
    :param myTree:
    :return:
    """
    # Used to save the maximum number of layers
    maxDepth = 0

    # Get root node
    firstStr = list(myTree.keys())[0]

    # Get the content corresponding to the key
    secondDic = myTree[firstStr]

    # Traverse all child nodes
    for key in secondDic.keys():
        # If the node is a dictionary, it is called recursively
        if isinstance(secondDic[key], dict):
            # Depth of child node plus 1
            thisDepth = 1 + getTreeDepth(secondDic[key])

        # This indicates that this is a leaf node
        else:
            thisDepth = 1

        # Replace maximum layers
        if thisDepth > maxDepth:
            maxDepth = thisDepth

    return maxDepth


def plotMidText(cntrPt, parentPt, txtString):
    """
    Calculate the middle position between the parent node and the child node, and fill in the information
    :param cntrPt: Child node coordinates
    :param parentPt: Parent node coordinates
    :param txtString: Filled text information
    :return:
    """
    # Calculate the middle position of the x-axis
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
    # Calculate the middle position of the y-axis
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
    # Draw
    createPlot.ax1.text(xMid, yMid, txtString)


def plotTree(myTree, parentPt, nodeTxt):
    """
    Draw all nodes of the tree and draw recursively
    :param myTree: tree
    :param parentPt: Coordinates of the parent node
    :param nodeTxt: Text information of the node
    :return:
    """
    # Calculate the number of leaf nodes
    numLeafs = getNumLeafs(myTree=myTree)

    # Calculate the depth of the tree
    depth = getTreeDepth(myTree=myTree)

    # Get the information content of the root node
    firstStr = list(myTree.keys())[0]

    # x of the current node = current x offset + (1 + numLeafs) / (2W), the middle of this subtree's leaf span.
    # On the first call xOff = -1/(2W) and numLeafs = W, so the root lands at -1/(2W) + (1 + W)/(2W) = 1/2.
    # y of the current node is the current y offset.
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)

    # Draw the connection between the node and the parent node
    plotMidText(cntrPt, parentPt, nodeTxt)

    # Draw the node
    plotNode(firstStr, cntrPt, parentPt, decisionNode)

    # Get the subtree corresponding to the current root node
    secondDict = myTree[firstStr]

    # Move the y offset down by 1/D to draw the next level
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD

    # Loop over all branches (keys)
    for key in secondDict.keys():
        # If the value for this key is a dict (a subtree), draw it recursively
        if isinstance(secondDict[key], dict):
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            # Calculate the new X-axis offset, that is, the x-axis coordinate drawn by the next leaf moves 1/W to the right
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            # Uncomment the next line to print the coordinates of each leaf
            # print((plotTree.xOff, plotTree.yOff), secondDict[key])
            # Draw leaf node
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            # Draw the content of the middle line between the leaf node and the parent node
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))

    # Before returning, move the y offset back up by 1/D to restore the parent's level
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD


def createPlot(inTree):
    """
    Decision tree to be drawn
    :param inTree: Decision tree dictionary
    :return:
    """
    # Create an image
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    # Calculate the total width of the decision tree
    plotTree.totalW = float(getNumLeafs(inTree))
    # Calculate the total depth of the decision tree
    plotTree.totalD = float(getTreeDepth(inTree))
    # The initial x offset is -1/(2W); it moves right by 1/W before each leaf is drawn, so the leaves are placed at x = 1/(2W), 3/(2W), 5/(2W), ..., (2W-1)/(2W)
    plotTree.xOff = -0.5/plotTree.totalW
    # The initial y-axis offset, moving down or up 1/D each time
    plotTree.yOff = 1.0
    # Call the function to draw the node image
    plotTree(inTree, (0.5, 1.0), '')
    # draw
    plt.show()
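The layout arithmetic in createPlot and plotTree can be checked with exact fractions. The hypothetical helper below (not part of the post) reproduces only the x-offset updates for W leaves: start at -1/(2W) and add 1/W before drawing each leaf. It also confirms that the root, with numLeafs = W, is centered at 1/2.

```python
from fractions import Fraction


def leaf_x_positions(W):
    """x-coordinates (axes fraction) assigned to W leaves by the xOff updates."""
    x_off = Fraction(-1, 2 * W)
    positions = []
    for _ in range(W):
        x_off += Fraction(1, W)  # move right by 1/W before drawing each leaf
        positions.append(x_off)
    return positions


W = 4
print(leaf_x_positions(W))  # [Fraction(1, 8), Fraction(3, 8), Fraction(5, 8), Fraction(7, 8)]
print(Fraction(-1, 2 * W) + Fraction(1 + W, 2 * W))  # 1/2: the root is centered
```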

3. Results

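dataSet, featureNames, featureNamesSet = readWatermelonDataSet()
# The last argument is the parent's label list; it is used only if the root has no samples
testTree = createFullDecisionTree(dataSet, featureNames, featureNamesSet, [x[-1] for x in dataSet])
createPlot(testTree)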



Posted on Sat, 06 Nov 2021 08:05:11 -0400 by ricroma