I just spent hours banging my head against the following problem in Python, and I can't quite figure out why it happened.
Say I'm making a decision tree class:
class DecisionTree:
    def __init__(self, name, children = dict()):
        self.name = name
        self.children = children  # <------- here's the problem
    def add_child(self, child_name, child_tree):
        self.children[child_name] = child_tree
    def pprint(self, depth = 0):
        for childkey in self.children:
            tabs = " ".join(["\t" for x in xrange(depth)])
            print tabs + 'if ' + self.name + ' is ' + str(childkey)
            if (self.children[childkey].__class__.__name__ == 'DecisionTree'):  # this is a tree, delve deeper
                print tabs + 'and...'
                child = self.children.get(childkey)
                child.pprint(depth + 1)
            else:
                val = self.children.get(childkey)
                print tabs + 'then predict ' + str(val)
            print "\n"
        return ''
Now let's build a nonsense tree and try to print it:
def make_a_tree(depth = 0, counter = 0):
    counter += 1
    if depth > 3:
        return 'bottom'
    else:
        tree = DecisionTree(str(depth) + str(counter))
        for i in range(2):
            subtree = make_a_tree(depth + 1, counter)
            tree.add_child(i, subtree)
        return tree
foo = make_a_tree()
foo.pprint()
This code recurses forever (it eventually dies with RuntimeError: maximum recursion depth exceeded), because the tree structure was somehow mistakenly built with the second node of the tree referring to itself.
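I can even observe the corruption directly. In a fresh interpreter session, two trees that should be completely independent turn out to share state (t1 and t2 are just throwaway names for this check):

t1 = DecisionTree('first')
t2 = DecisionTree('second')
t1.add_child('a', 'leaf')         # add a child to t1 only
print t2.children                 # {'a': 'leaf'} -- t2 sees it too!
print t1.children is t2.children  # True: both instances hold the very same dict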
If I change the line I've marked above to self.children = dict(), then things work properly.
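Concretely, this version prints fine, even though it now ignores the children argument entirely, which only confuses me more:

def __init__(self, name, children = dict()):
    self.name = name
    self.children = dict()  # a fresh dict on every call; the argument is never used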
I can't wrap my head around what's happening here. The intention behind the code as written is to take an argument for children and, if none is passed, create an empty dictionary and use that as the children.
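In other words, I assumed the default would be evaluated fresh on every call, i.e. that children = dict() was shorthand for something like this (a sketch of the semantics I expected, not code I've actually run in the class):

def __init__(self, name, children = None):
    if children is None:
        children = dict()  # what I assumed happened: a brand-new dict per call
    self.name = name
    self.children = children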
I'm pretty new to Python, and I'm trying to make this a learning experience. Any help would be appreciated.