
I'm looking for a basic pseudo-code outline here.

My goal is to code a classification tree from scratch (I'm learning machine learning and want to get intuition). But my training data is huge: 40000 examples and 1000 features. Given that the upper bound for the number of splits needed is 2^40000, I'm lost as to how to keep track of all these partitioned datasets.

Say I start with the full dataset and take one split. Then I can save the 20000ish examples that fell on one side of the split into a dataset, and re-run the splitting algorithm to find the greedy split for that dataset. Then say I keep doing this, splitting along the leftmost branches of the tree dozens of times.

When I'm satisfied with all my leftmost splits, then what? How do I store up to 2^40000 separate subsets? And how do I keep track of all the splits I've taken for when I'm classifying a test example? It's the organization of the code that's not making sense to me.

user1956609
  • The number of splits is not 2^40000; rather, 2^N = 40000 gives N = log2(40000) < 16, and a tree only 16 levels deep already has more than 40000 leaf nodes. 2^40000 is a mind-bogglingly large number. – Cris Luengo Aug 06 '21 at 14:55

2 Answers


Thanks @natan for a detailed answer.

However, if I understand correctly, the main issue you are concerned with is how to efficiently track each training sample as it propagates through the decision tree.

This can be done rather easily.

All you need is a vector of size N=40000 with an entry per training sample. This vector will tell you where in the tree each sample is located. Let's call this vector assoc.

How to use this vector?

In my opinion, the most elegant way is to make assoc of type uint32 and use its bits to encode the propagation of each training sample through the tree.

Each bit in assoc(k) represents a certain depth of the tree: if that bit is set (1), it means that sample k went right at that depth; otherwise, sample k went left.

If you decide to go along with this strategy, you'll find the following Matlab commands useful: bitget, bitset, bitshift, and other bit-wise functions.
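For example, here is a minimal sketch of how assoc might be initialized and updated when a node is split (the depth variable d and the even/odd "split rule" below are purely illustrative, not a real splitting criterion):

N = 40000;
assoc = zeros(N, 1, 'uint32');      % every sample starts at the root, all bits 0

d = 1;                              % depth of the node being split (root = depth 1)
wentRight = mod((1:N)', 2) == 0;    % dummy rule: pretend even-indexed samples go right
assoc(wentRight) = bitset(assoc(wentRight), d, 1);   % bit d = 1 marks "went right" at depth d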

Let's consider the following tree:

       root
      /    \
     a      b
           / \
          c   d

So, for all examples that went to node a, the assoc value is 00b, since they went left at the root (a zero in the least significant bit (LSB)).

All examples that went to leaf node c have the assoc value 01b: they went right at the root (LSB = 1), then turned left (2nd bit = 0).

Finally, all examples that went to leaf node d have the assoc value 11b: they took two right branches.

Now, how can you find all the examples that went through node b?

It's easy!

>> selNodeB = bitand( assoc, 1 );

All samples whose LSB is 1 took the right turn at the root and therefore passed through node b.
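The same trick generalizes to any internal node: mask out the bits for the decisions made before reaching that node's depth and compare them with the node's path. A small sketch, assuming the assoc vector above (d and pathBits are just illustrative names):

d        = 2;                        % depth of node b in the example tree
pathBits = uint32(1);                % path to b: right at the root (LSB = 1)
mask     = uint32(2^(d-1) - 1);      % keep only the d-1 decisions taken before depth d
selNodeB = bitand(assoc, mask) == pathBits;   % logical mask of samples that passed through b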

Shai

If you think there is a way to store 2^40000 bits, you haven't realized how big this number is; you are off by roughly 10,000 orders of magnitude. Check Matlab's documentation of classregtree.

I've copied from @Amro's detailed answer (found here):

" Here are a few common parameters of the classification tree model:

  • x: data matrix, rows are instances, cols are predicting attributes
  • y: column vector, class label for each instance
  • categorical: specify which attributes are discrete type (as opposed to continuous)
  • method: whether to produce a classification or a regression tree (depends on the class type)
  • names: gives names to the attributes
  • prune: enable/disable reduced-error pruning
  • minparent/minleaf: allows specifying the minimum number of instances in a node for it to be further split
  • nvartosample: used in random trees (consider K randomly chosen attributes at each node)
  • weights: specify weighted instances
  • cost: specify cost matrix (penalty of the various errors)
  • splitcriterion: criterion used to select the best attribute at each split. I'm only familiar with the Gini index which is a variation of the Information Gain criterion.
  • priorprob: explicitly specify prior class probabilities, instead of being calculated from the training data

A complete example to illustrate the process:

%# load data
load carsmall

%# construct predicting attributes and target class
vars = {'MPG' 'Cylinders' 'Horsepower' 'Model_Year'};
x = [MPG Cylinders Horsepower Model_Year];  %# mixed continuous/discrete attributes
y = strcat(Origin,{});                      %# class labels, one per instance

%# train classification decision tree
t = classregtree(x, y, 'method','classification', 'names',vars, ...
                'categorical', [2 4], 'prune','off');
view(t)

%# test
yPredicted = eval(t, x);
cm = confusionmat(y,yPredicted);           %# confusion matrix
N = sum(cm(:));
err = ( N-sum(diag(cm)) ) / N;             %# testing error

%# prune tree to avoid overfitting
tt = prune(t, 'level',2);
view(tt)

%# predict a new unseen instance
inst = [33 4 78 NaN];
prediction = eval(tt, inst)

[figure: graphical view of the trained classification tree]
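As a side note, classregtree has since been deprecated in newer MATLAB releases in favor of fitctree from the Statistics and Machine Learning Toolbox. The following is a rough, untested sketch of the same training and prediction steps with that interface (cellstr(Origin) is assumed here simply to get one class label per instance):

%# train with the newer fitctree interface
load carsmall
vars = {'MPG' 'Cylinders' 'Horsepower' 'Model_Year'};
x = [MPG Cylinders Horsepower Model_Year];
y = cellstr(Origin);                        %# class labels, one per instance

t = fitctree(x, y, 'PredictorNames',vars, 'CategoricalPredictors',[2 4]);
view(t, 'Mode','graph')

%# predict a new unseen instance
inst = [33 4 78 NaN];
prediction = predict(t, inst)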

bla