# decision tree

__author__ = 'HM'
from treelib import *
import math
import uuid
#-------------------------------------------------------------------
#data and variant definition
'''
attribute_discrete_pool = {'occupation':set('student','teacher'),'sex':set('male','female')}
attribute_continuous_pool = {'salary':(0,10000),'age':(0,100)}

data_set_raw = [{'class_label':1,'name':'Tom','age':10},
                {'class_label':0,'name':'Jerry','age':90}
]
class_label_pool = [1,0]
'''
f = open('data.txt','r')
first_line = f.readline().split()
attributes = first_line[:-1]
attr_len = len(attributes)
classname = first_line[-1]
data_set_raw = []
class_label_pool = set()
attribute_discrete_pool = {}
for line in f:
    raw_data = line.split()
   # new_record = {classname:raw_data[-1]}
    new_record = {'class_label':raw_data[-1]}
    class_label_pool.add(raw_data[-1])
    for i in xrange(attr_len):
        attribute_name = attributes[i]
        new_record[attribute_name] = raw_data[i]

        attribute_discrete_pool[attribute_name]=attribute_discrete_pool.get(attribute_name,set()).union(set([raw_data[i]]))
    data_set_raw.append(new_record)
print attribute_discrete_pool
'''
for i in data_set_raw:
    print i
'''
#--------------------------------------------------------------------
class innerNode(Node):
    # Internal decision node.  ``attribute`` (stored as the treelib tag) is
    # the attribute tested at this node; ``parent_node_assert`` is the
    # predicate a record's attribute value must satisfy to descend here.
    # NOTE(review): this bypasses Node.__init__ and assigns treelib's
    # private fields directly -- it assumes an old treelib API; confirm
    # against the installed treelib version.
    def __init__(self,attribute,parent_node_assert,identifier=None, expanded=True):
        self.parent_node_assert = parent_node_assert
        self.tag = attribute

        self._identifier = self.set_identifier(identifier)
        self.expanded = expanded
        self._bpointer = None   # parent pointer (treelib internal)
        self._fpointer = []     # child pointers (treelib internal)

    def __str__(self):
        # shows the branch predicate, not the attribute name
        return str(self.parent_node_assert)

class leafNode(Node):
    # Leaf of the decision tree.  ``class_label`` is the predicted class and
    # is stored as the treelib tag; ``parent_node_assert`` is the predicate
    # the attribute value must satisfy to reach this leaf.
    # NOTE(review): like innerNode, this bypasses Node.__init__ and assigns
    # treelib's private fields directly (old treelib API).
    def __init__(self,class_label,parent_node_assert,identifier=None, expanded=True):
        self.tag = class_label
        self.parent_node_assert = parent_node_assert

        self._identifier = self.set_identifier(identifier)
        self.expanded = expanded
        self._bpointer = None
        self._fpointer = []

    def __str__(self):
        # bug fix: __init__ stores the label in self.tag; self.class_label
        # never existed, so str() on a leaf raised AttributeError.
        return str(self.tag)

def info(dataset):
    """Shannon entropy (in bits) of the class-label distribution of dataset."""
    total = float(len(dataset))
    counts = {}
    for record in dataset:
        label = record['class_label']
        counts[label] = counts.get(label, 0) + 1
    entropy = 0
    for n in counts.values():
        p = n / total
        entropy -= p * math.log(p, 2)
    return entropy

def info_a(attribute, dataset):
    """Expected entropy of ``dataset`` after partitioning on ``attribute``.

    Returns (weighted_entropy, partitions) where ``partitions`` maps each
    observed attribute value to the list of records carrying that value.
    """
    partitions = {}
    for record in dataset:
        partitions.setdefault(record[attribute], []).append(record)
    total = float(len(dataset))
    weighted = 0
    for part in partitions.values():
        weighted += (len(part) / total) * info(part)
    return weighted, partitions

def info_gain(attribute, dataset):
    '''
    Compute the information gain of splitting ``dataset`` on ``attribute``.

    Returns (gain, sub_dataset) where ``sub_dataset`` maps each attribute
    value to the list of records carrying it.  Only discrete attributes
    are supported.
    '''
    if attribute in attribute_discrete_pool:  # membership test; .keys() was redundant
        # discrete attribute: gain = info(D) - info_A(D)
        infoa, sub_dataset = info_a(attribute, dataset)
        gain_a = info(dataset) - infoa
        return gain_a, sub_dataset
    # bug fix: the continuous branch used to ``pass`` and implicitly return
    # None, which made the caller's tuple unpacking crash with an opaque
    # TypeError.  Fail loudly instead.
    raise NotImplementedError('continuous attribute %r is not supported' % (attribute,))

def get_best_attribute(attribute_list, dataset):
    """Select the attribute from ``attribute_list`` with the highest
    information gain over ``dataset``.

    Returns (best_attribute, partitions); ``partitions`` maps each value of
    the winning attribute to its list of records.
    """
    best_attribute = None
    best_partitions = {}
    best_score = -1
    for candidate in attribute_list:
        score, partitions = info_gain(candidate, dataset)
        if score > best_score:
            best_score = score
            best_attribute = candidate
            best_partitions = partitions
    return best_attribute, best_partitions  # partitions format: {attribute_value: [records]}

def true(x):
    # Constant-True predicate, used as the assertion of the root node and of
    # leaves.  (PEP 8: use def instead of binding a lambda to a name.)
    return True

def vote_class_label(dataset):
    """Return the majority class label among the records in ``dataset``.

    Bug fix: the previous version sorted the (label, count) pairs by count
    ascending and took the first entry, which returned the *least* common
    label instead of the majority vote.
    """
    class_label_count = {}
    for d in dataset:
        label = d['class_label']
        class_label_count[label] = class_label_count.get(label, 0) + 1
    # most frequent label wins the vote
    return max(class_label_count.items(), key=lambda kv: kv[1])[0]

def create_decision_tree(attribute_list,dataset,assert_func,parent_id):
    '''
    this function create decision tree recursively
    input:    attribute_list: current unused attribute set
              dataset: data set
              assert_func: parent node assert
              parent_id: parent id uuid
    output: decision tree
    '''
    #1.exit recurse
    if len(attribute_list)==0:
        tree = Tree()
        inner_node_id = uuid.uuid1()
        tree.add_node(innerNode(None,assert_func,inner_node_id))
        tree.add_node(leafNode(vote_class_label(dataset),true,uuid.uuid1()),parent=inner_node_id)
        return tree

    #2.recurse
    print attribute_list
    attribute_list_local = attribute_list[:]
    best_attribute,sub_dataset = get_best_attribute(attribute_list_local,dataset)
    attribute_list_local.remove(best_attribute)
    print "the best attribute is:",best_attribute

    node_id = uuid.uuid1()
    tree = Tree()
    tree.add_node(innerNode(best_attribute,assert_func,node_id))

    for attribute_value in sub_dataset:
        # sub_data:attribute value  ;sub_dataset[sub_data]: data set
        print '-'*40
        print attribute_value
        func = lambda x: x==attribute_value
        sub_tree = create_decision_tree(attribute_list_local,sub_dataset[attribute_value],func,node_id)
        tree.paste(node_id,sub_tree)
    return tree

    '''
    true = lambda x:True

    tree = Tree()
    tree.add_node(innerNode('salary', true,1))
    tree.add_node(innerNode(None,lambda x:x>10,2),parent=1)
    tree.add_node(innerNode(None,lambda x:x<=10,3),parent=1)

    tree.add_node(leafNode("yes",true,4),parent=2)
    tree.add_node(leafNode("no",true,5),parent=3)

    tree.show()
    return tree
    '''

def predict(v,decision_tree):
    '''
    Walk ``decision_tree`` for the record ``v`` (dict: attribute -> value)
    and return the predicted class label.
    '''
    current_node = decision_tree.get_node(decision_tree.root)
    current_node_sons_ids = current_node.fpointer
    v_attribute = v[current_node.tag]
    while 1:
        matched = False
        for i in current_node_sons_ids:
            son = decision_tree.get_node(i)
            if son.parent_node_assert(v_attribute):
                current_node = son
                matched = True
                break
        # robustness fix: the old loop spun forever when no branch accepted
        # the value (e.g. an attribute value unseen during training)
        if not matched:
            raise ValueError('no branch matches attribute value %r' % (v_attribute,))
        # isinstance instead of type()==; the old code also used the
        # removed ``<>`` operator for the None check below
        if isinstance(current_node, leafNode):
            return current_node.tag
        if current_node.tag is not None:
            v_attribute = v[current_node.tag]
        current_node_sons_ids = current_node.fpointer

# Build the tree over every discrete attribute and display it.
root_node_id = uuid.uuid1()
# dead-code fix: the old code also built a ``root_node`` innerNode here
# that was never attached to anything.
tree = create_decision_tree(attribute_discrete_pool.keys(),data_set_raw,true,root_node_id)
tree.show()

# See also: the official DecisionTree package, https://pypi.python.org/pypi/DecisionTree
# (the remainder of the scraped blog page -- comment-widget boilerplate -- removed)