
I am trying to find the time complexity of a binary decision tree algorithm. I understand that at each node, the cost is bounded by the cost of searching for the best attribute, O(m n log n), where m is the number of features and n is the number of examples in the training set. I think we should multiply O(m n log n) by the number of nodes to get the complexity of the whole algorithm, right? I don't understand why some resources give the complexity of a decision tree as only O(m n log n)!

Can anyone explain this? Is there any difference in how the complexity is computed depending on whether categorical or continuous attributes are used?

Adrian Mole
Sarra

2 Answers


You only have to sort once, which costs O(m*n*log(n)). Building the decision tree from the presorted arrays is asymptotically cheaper than that.

For example, consider constructing the root node of the decision tree. We take the greedy approach and iterate over each pivot point (n) in each feature (m), computing the loss each time, which gives a cost of order n*m. The next level of the tree requires a similar cost: even though the data is partitioned, the partitions together still contain all n points and each of them is iterated over. Therefore the cost of building a tree of depth d is n*m*d. This is typically negligible compared with the sorting time unless the tree depth is comparable to, or larger than, log(n).
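A minimal Python sketch of that per-node scan, assuming NumPy, integer class labels and Gini impurity as the loss (the answer does not name a loss); `best_split` and its arguments are hypothetical names. Each feature's samples are assumed to already be in sorted order, so a running label count lets every pivot be scored in O(1) for a fixed number of classes, giving O(n*m) for the node.

```python
import numpy as np

def best_split(X, y, sorted_idx, n_classes):
    """Greedy search over every pivot of every feature for one node.

    X: (n, m) float array; y: (n,) int labels in [0, n_classes).
    sorted_idx: list of m index arrays; sorted_idx[j] holds this node's
    sample indices ordered by X[:, j] (assumed presorted).
    Returns (weighted_gini, feature, threshold).
    """
    n = len(sorted_idx[0])
    total = np.bincount(y[sorted_idx[0]], minlength=n_classes).astype(float)
    best = (np.inf, None, None)
    for j, order in enumerate(sorted_idx):          # m features
        left = np.zeros(n_classes)
        for k in range(n - 1):                      # n - 1 pivot points
            left[y[order[k]]] += 1                  # running label count, O(1) update
            a, b = X[order[k], j], X[order[k + 1], j]
            if a == b:
                continue                            # no threshold between equal values
            n_left = k + 1
            right = total - left
            gini_l = 1.0 - np.sum((left / n_left) ** 2)
            gini_r = 1.0 - np.sum((right / (n - n_left)) ** 2)
            score = (n_left * gini_l + (n - n_left) * gini_r) / n
            if score < best[0]:
                best = (score, j, (a + b) / 2)
    return best
```

For X = [[1], [2], [3], [4]] with labels [0, 0, 1, 1], the scan returns a pure split (impurity 0.0) on feature 0 at threshold 2.5.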

If your decision tree is not given a maximum depth but is grown until every data point is classified, the depth can be anywhere between log(n) and n, for completely balanced and completely unbalanced trees respectively.


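The naive implementation re-sorts at every node. Since the nodes on one level together hold at most n samples, this amounts to multiplying m*n*log(n) by the depth of the tree, which is log(n) in the best case (balanced tree) and n in the worst case.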
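To see how re-sorting at every node scales with depth, here is a rough Python count of comparisons (m*s*log2(s) for a node of s samples) for two hypothetical fully grown tree shapes: a balanced tree that halves every node, and a worst-case chain that splits off one sample at a time. The helper `naive_sort_cost` is made up for illustration, and n is assumed to be a power of two in the balanced case.

```python
import math

def naive_sort_cost(n, m, balanced=True):
    """Rough comparison count when every node re-sorts its own samples.

    Hypothetical model of a fully grown tree: balanced=True halves every
    node (about log2(n) levels, n assumed a power of two); balanced=False
    is the worst-case chain that splits off one sample per split (about n levels).
    """
    def sort_cost(s):
        # Sorting s samples on each of m features costs ~ m * s * log2(s).
        return m * s * math.log2(s) if s > 1 else 0.0

    total = 0.0
    if balanced:
        size, nodes = n, 1
        while size > 1:                  # one iteration per level
            total += nodes * sort_cost(size)
            size //= 2
            nodes *= 2
    else:
        for size in range(n, 1, -1):     # one non-leaf node per level
            total += sort_cost(size)
    return total
```

For n = 1024 and m = 1, the balanced case returns 56320.0 (1024 * (10 + 9 + ... + 1)), about 5.5 times a single sort of 1024 * 10 = 10240; the chain case grows roughly like m*n^2*log(n)/2 and is far larger.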

But by caching the sort order, the sorting can be done once and for all in O(m*n*log(n)). After that, finding the best split at a node costs O(m) per sample that reaches it, i.e. O(n*m) per level of the tree, since the sorting is already done. Scikit-learn claims to reduce the cost at each node to O(m*log(n)). src: https://scikit-learn.sourceforge.net/dev/modules/tree.html#complexity
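One way to read "caching" here is to argsort every feature once and then, at each split, partition each feature's index list while keeping its order instead of re-sorting it. A minimal NumPy sketch of that idea follows; `presort` and `partition_sorted` are hypothetical names, and this is not scikit-learn's actual implementation.

```python
import numpy as np

def presort(X):
    """Sort each of the m features once: O(m * n log n) in total."""
    return [np.argsort(X[:, j], kind="stable") for j in range(X.shape[1])]

def partition_sorted(sorted_idx, go_left):
    """Split a node's presorted index lists into its two children.

    go_left: boolean array over all n samples, True for samples sent left.
    Boolean masking keeps the surviving indices in their existing order, so
    each child's lists are still sorted: O(n_node) work per feature, hence
    O(n * m) for a whole level of the tree.
    """
    left = [order[go_left[order]] for order in sorted_idx]
    right = [order[~go_left[order]] for order in sorted_idx]
    return left, right
```

Because every sample at a level is touched once per feature, this partitioning step costs the same O(n*m) per level as the split search, so the one-time sort stays the only O(m*n*log(n)) term for a balanced tree.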

Therefore, with about log(n) levels, the overall time complexity is O(m*n*log(n)) + O(n*m*log(n)), which is O(m*n*log(n)).
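Written as a sum, assuming O(m*n) split-search work per level once the data is presorted and writing d for the depth of the tree (the same symbol the first answer uses), the two phases combine as follows; the second line covers the balanced and fully unbalanced cases:

```latex
\begin{aligned}
T(n, m) &= \underbrace{O(m\,n\log n)}_{\text{sort once}}
         + \underbrace{\sum_{\ell=1}^{d} O(m\,n)}_{\text{split search, one level at a time}}
         = O(m\,n\log n + m\,n\,d),\\
d \approx \log n \;&\Rightarrow\; T = O(m\,n\log n), \qquad
d \approx n \;\Rightarrow\; T = O(m\,n^2).
\end{aligned}
```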

M . Franklin
  • `O(m*n*log(n)) + O(n*m*log(n))` is the time complexity in the best case; in the worst case it is `O(m*n*log(n)) + O(m*n^2)`, which is `O(m*n^2)` – M . Franklin Mar 31 '23 at 10:03