
Tree class in MATLAB

I am implementing a tree data structure in MATLAB. Adding new child nodes to the tree, assigning and updating data values related to the nodes are typical operations that I expect to execute. Each node has the same type of data associated with it. Removing nodes is not necessary for me. So far, I've decided on a class implementation inheriting from the handle class to be able to pass references to nodes around to functions that will modify the tree.

Edit: December 2nd

First of all, thanks for all the suggestions in the comments and answers so far. They have already helped me to improve my tree class.

Someone suggested trying digraph, introduced in R2015b. I have yet to explore this, but since it cannot be passed around by reference the way a class inheriting from handle can, I am a bit sceptical about how it will work in my application. It is also not yet clear to me how easy it will be to work with custom data for nodes and edges.
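One way to check both concerns: a minimal sketch, assuming R2015b or newer, of a small tree stored in a digraph with per-node data held as extra columns of the G.Nodes table. The column names Score and Visits are hypothetical. The sketch also shows the value semantics in question: addnode and addedge return a new graph that has to be reassigned.

```matlab
% Sketch, assuming R2015b+: a tree as a digraph, node data as table columns.
G = digraph([1 1], [2 3]);            % edges 1->2 and 1->3; node 1 is the root
G.Nodes.Score  = zeros(3, 1);          % hypothetical per-node data columns
G.Nodes.Visits = zeros(3, 1);

% digraph is a value class: every modification returns a new graph, so the
% result must be reassigned (it is not updated in place like a handle object).
G = addnode(G, table(0, 0, 'VariableNames', {'Score', 'Visits'}));  % node 4
G = addedge(G, 2, 4);                  % make node 4 a child of node 2

G.Nodes.Visits(4) = G.Nodes.Visits(4) + 1;   % update the data of one node
kids   = successors(G, 1);             % children of the root
parent = predecessors(G, 4);           % parent of node 4
```

Because of the value semantics, a function that modifies the tree would have to return G to its caller, which is the main practical difference from the handle-based class.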

Edit: (Dec 3rd) Further information on the main application: MCTS

Initially, I assumed the details of the main application would only be of marginal interest, but since reading the comments and the answer by @FirefoxMetzger, I realise that it has important implications.

I am implementing a type of Monte Carlo tree search algorithm. A search tree is explored and expanded in an iterative manner. Wikipedia offers a nice graphical overview of the process: Monte Carlo tree search

In my application I perform a large number of search iterations. On every search iteration, I traverse the current tree starting from the root until a leaf node, then expand the tree by adding new nodes, and repeat. As the method is based on random sampling, at the start of an iteration I do not know which leaf node it will finish at. Instead, this is determined jointly by the data of the nodes currently in the tree and the outcomes of random samples. Whatever nodes I visit during a single iteration have their data updated.

Example: I am at node n which has a few children. I need to access data in each of the children and draw a random sample that determines which of the children I move to next in the search. This is repeated until a leaf node is reached. Practically I am doing this by calling a search function on the root that will decide which child to expand next, call search on that node recursively, and so on, finally returning a value once a leaf node is reached. This value is used while returning from the recursive functions to update the data of the nodes visited during the search iteration.
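The iteration described above (select down to a leaf, expand, update the visited nodes on the way back) can be sketched on flat arrays. This is a minimal sketch under stated assumptions: the array names, the "probability proportional to score" selection rule and the use of rand as a rollout value are all hypothetical placeholders, and the recursion is replaced by an explicit list of visited node IDs.

```matlab
% Hypothetical flat storage (names are illustrative, not from the class below):
% Parent, Score and Visits as arrays, Children as a cell array of child-ID
% row vectors so that a node's children are found without calling find().
Parent    = [0 1 1 2];
Children  = {[2 3], 4, [], []};
Score     = [1 1 1 1];
Visits    = [0 0 0 0];
num_nodes = 4;

% 1) Selection: walk from the root to a leaf. At each node a child is drawn
%    with probability proportional to its score (a placeholder rule); the
%    visited IDs are kept in an explicit list instead of a recursion stack.
visited = 1;
while ~isempty(Children{visited(end)})
    c = Children{visited(end)};
    w = cumsum(Score(c));
    visited(end+1) = c(find(rand * w(end) <= w, 1));   %#ok<AGROW>
end

% 2) Expansion: append one new child below the leaf that was reached
%    (preallocation is omitted here for brevity).
leaf = visited(end);
num_nodes = num_nodes + 1;
Parent(num_nodes)     = leaf;
Children{leaf}(end+1) = num_nodes;
Children{num_nodes}   = [];
Score(num_nodes)      = 1;
Visits(num_nodes)     = 0;
visited(end+1)        = num_nodes;

% 3) Backpropagation: update every node on the path with one rollout value
%    (rand stands in for the result of a random simulation).
value = rand;
Visits(visited) = Visits(visited) + 1;
Score(visited)  = Score(visited) + value;
```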

The tree may be quite unbalanced such that some branches are very long chains of nodes, while others terminate quickly after the root level and are not expanded further.

Current implementation

Below is my current implementation, with examples of a few of the member functions for adding nodes, querying the depth or number of nodes in the tree, and so on.

classdef stree < handle
    %   A class for a tree object that acts like a reference
    %   parameter.
    %   The tree can be traversed in both directions by using the parent
    %   and children information.
    %   New nodes can be added to the tree. The object will automatically
    %   keep track of the number of nodes in the tree and increment the
    %   storage space as necessary.

    properties (SetAccess = private)
        % Hold the data at each node
        Node = { [] };
        % Index of the parent node. The root of the tree has a parent index
        % equal to 0.
        Parent = 0;
        num_nodes = 0;
        size_increment = 1;
        maxSize = 1;
    end

    methods
        function [obj, root_ID] = stree(data, init_siz)
            % New object with only root content, with specified initial
            % size
            obj.Node = repmat({ data },init_siz,1);
            obj.Parent = zeros(init_siz,1);
            root_ID = 1;
            obj.num_nodes = 1;
            obj.size_increment = init_siz;
            obj.maxSize = numel(obj.Parent);
        end

        function ID = addnode(obj, parent, data)
            % Add child node to specified parent
            if obj.num_nodes < obj.maxSize
                % still have room for data
                idx = obj.num_nodes + 1;
                obj.Node{idx} = data;
                obj.Parent(idx) = parent;
                obj.num_nodes = idx;
            else
                % all preallocated elements are in use, reserve more memory
                obj.Node = [
                    obj.Node
                    repmat({data},obj.size_increment,1)
                    ];

                obj.Parent = [
                    obj.Parent
                    parent
                    zeros(obj.size_increment-1,1)];
                obj.num_nodes = obj.num_nodes + 1;

                obj.maxSize = numel(obj.Parent);

            end
            ID = obj.num_nodes;
        end

        function content = get(obj, ID)
            %% GET  Return the contents of the given node IDs.
            content = [obj.Node{ID}];
        end

        function obj = set(obj, ID, content)
            %% SET  Set the content of given node ID and return the modified tree.
            obj.Node{ID} = content;
        end

        function IDs = getchildren(obj, ID)
            % GETCHILDREN  Return the list of IDs of the children of the given node ID.
            % The list is returned as a row vector.
            IDs = find( obj.Parent(1:obj.num_nodes) == ID );
            IDs = IDs';
        end
        function n = nnodes(obj)
            % NNODES  Return the number of nodes in the tree.
            % Equal to root + all nodes whose parent index is nonzero.
            n = 1 + sum(obj.Parent(1:obj.num_nodes) ~= 0);
            assert( obj.num_nodes == n);
        end

        function flag = isleaf(obj, ID)
            % ISLEAF  Return true if given ID matches a leaf node.
            % A leaf node is a node that has no children.
            flag = ~any( obj.Parent(1:obj.num_nodes) == ID );
        end

        function depth = depth(obj,ID)
            % DEPTH  Return the depth of the tree under ID. If ID is not
            % given, use the root (ID 1; ID 0 would add a level above it).
            if nargin == 1
                ID = 1;
            end
            if obj.isleaf(ID)
                depth = 0;
            else
                children = obj.getchildren(ID);
                NC = numel(children);
                d = 0; % Depth from here on out
                for k = 1:NC
                    d = max(d, obj.depth(children(k)));
                end
                depth = 1 + d;
            end
        end
    end
end
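Two of these queries can be computed without recursion or a find() call per node. A minimal sketch, assuming (as addnode implies in practice) that every node is added after its parent, so that Parent(k) < k; the example Parent vector is hypothetical:

```matlab
% Hypothetical Parent vector in the same layout as the class above
% (root has parent 0; each node is added after its parent, so Parent(k) < k).
Parent    = [0 1 1 2 4 4];
num_nodes = numel(Parent);

% Depth of every node in one forward pass (the root has depth 0).
node_depth = zeros(num_nodes, 1);
for k = 2:num_nodes
    node_depth(k) = node_depth(Parent(k)) + 1;
end
tree_depth = max(node_depth);   % edges on the longest root-to-leaf path

% Number of children of every node, also in one pass (no find per query).
p = Parent(2:num_nodes);
num_children = accumarray(p(:), 1, [num_nodes 1]);
is_leaf = num_children == 0;
```

The forward loop relies entirely on the Parent(k) < k ordering; if nodes could ever be attached to a parent with a larger ID, it would read depths that have not been computed yet.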

However, performance is at times slow, with operations on the tree taking up most of my computation time. What specific ways are there to make the implementation more efficient? It would even be possible to change the implementation to something other than handle inheritance if there are performance gains.

Profiling results with current implementation

As adding new nodes to the tree is the most typical operation (along with updating the data of a node), I did some profiling on that. I ran the profiler on the following benchmarking code with Nd=6, Ns=10.

function T = benchmark(Nd, Ns)
% Tree benchmark. Nd: tree depth, Ns: number of nodes per layer
% Initialize tree
T = stree(rand, 10000);
add_layers(1, Nd);
    function add_layers(node_id, num_layers)
        if num_layers == 0
            return;
        end
        child_id = zeros(Ns,1);
        for s = 1:Ns
            % add child to current node
            child_id(s) = T.addnode(node_id, rand);

            % recursively increase depth under child_id(s)
            add_layers(child_id(s), num_layers-1);
        end
    end
end

Results from the profiler: [image: profiler results]

R2015b performance

It has been discovered that R2015b improves the performance of MATLAB's OOP features. I redid the above benchmark and indeed observed an improvement in performance:

[image: R2015b profiler result]

So this is already good news, although further improvements are of course accepted ;)

Reserving memory differently

It was also suggested in the comments to use

obj.Node = [obj.Node; data; cell(obj.size_increment - 1,1)];

to reserve more memory rather than the current approach with repmat. This improved performance slightly. I should note that my benchmark code is for dummy data, and since the actual data is more complicated this is likely to help. Thanks! Profiler results below:

[image: profiler results with zeeMonkeez's memory reserve style]

Questions on even further increasing performance

  1. Perhaps there is an alternative way to maintain memory for the tree that is more efficient? Sadly, I typically don't know ahead of time how many nodes there will be in the tree.
  2. Adding new nodes and modifying the data of existing nodes are the most typical operations I do on the tree. As of now, they actually take up most of the processing time of my main application. Any improvements on these functions would be most welcome.

Just as a final note, I would ideally like to keep the implementation as pure MATLAB. However, options such as MEX or using some of the integrated Java functionalities may be acceptable.

mikkola
  • Running the `profiler` can illuminate quite a bit in your code in terms of performance. Run it once and see where the code is exceptionally slow, it'll give you a pointer where to start improving. – Adriaan Nov 13 '15 at 15:26
  • [Matlab OOP adds a significant overhead unless you use Matlab 2015b or newer](http://stackoverflow.com/questions/1693429/is-matlab-oop-slow-or-am-i-doing-something-wrong) which probably causes the problems. Not using `handle` probably won't help. – Daniel Nov 13 '15 at 16:05
  • @Adriaan thanks for the suggestion. I added some profiler data. – mikkola Nov 13 '15 at 16:27
  • @Daniel very interesting read, thank you! I've certainly noticed that "vectorizing" Matlab OOP code by structuring objects to wrap arrays instead of creating arrays of objects has a huge effect on performance. However, based on the profiler results I also suspect there is something else going on here, as well. – mikkola Nov 13 '15 at 17:32
  • If switching to R2015b for the OOP performance benefits, you might also want to look at implementing the tree structure using the new [`digraph`](http://mathworks.com/help/matlab/ref/digraph-object.html) class instead. I haven't used it extensively myself but I'd expect anything that can be implemented in terms of its native operations to be much better optimised. – Will Dec 01 '15 at 14:31
  • If you often find you add nodes in batches, it might be beneficial to add a batch add function. – zeeMonkeez Dec 01 '15 at 20:56
  • Also, depending on what your data is, using `repmat` to allocate node data might add a lot of overhead. Why not initialize with `obj.Node = [obj.Node; data; cell(obj.size_increment - 1,1)];`? – zeeMonkeez Dec 01 '15 at 21:06
  • How are you accessing the nodes when you modify data? More specifically, why are you saving the node's parent and not its children? From the looks of it, you might be faster using a single lookup table or struct to store your data. –  Dec 02 '15 at 13:10
  • @FirefoxMetzger I edited the question to respond to your comment. Your suggestion is sensible, it's just that current profiling results show most of the time spent elsewhere. I will however look into implementing your suggestion as time permits. – mikkola Dec 02 '15 at 13:26
  • @Will thanks for the suggestion! So far looking at the documentation, I'm not convinced I want to go for `digraph`. I edited the question to give some justification. – mikkola Dec 02 '15 at 13:28
  • @zeeMonkeez I only add nodes one at a time, so batch is not needed. Your suggestion on not using `repmat` seems useful, thank you! I observed improvements even with the dummy data used in benchmarking. Updated the post to include profiler results. – mikkola Dec 02 '15 at 13:34
  • @mikkola `digraph` implements both nodes and edges as MATLAB tables, so storing additional data for nodes is just a case of adding columns to that table. If the data for each node itself needs to be handled in an object-oriented way, you could store an object handle in a column. For object-oriented traversal of the tree itself, the cleanest approach would probably to be to subclass `digraph` itself though. – Will Dec 02 '15 at 14:23
  • Is there a reason why you flattened the tree? I mean, why are the elements stored in a linear data array? For a tree each children should be a tree itself. – NicolaSysnet Dec 03 '15 at 09:39
  • @NicolaSysnet I very well understand your point. The reason is related to Matlab's OOP weaknesses. It is (at least up until R2015b) much faster to have an object wrap an array instead of having an array of objects, see [this answer](http://stackoverflow.com/a/1745686/5471520). – mikkola Dec 03 '15 at 09:55
  • What is true for someone isn't true for everyone: in your application you keep extending a cell array, which is a very slow operation in Matlab. If your application needs extending and shrinking of the array, and you **never** need to apply a function to all the elements of the tree (as in `cellfun(fun,stree.Node)`), then the overhead of the Matlab OOP weaknesses is balanced by the savings due to not extending the cell array. – NicolaSysnet Dec 03 '15 at 10:06

3 Answers


TL;DR: You deep-copy the entire stored data on each insertion. Initialize the Parent and Node cells bigger than you expect to need.

Your data does have a tree structure, however you are not utilising this in your implementation. Instead, the implemented code is a computationally hungry version of a lookup table (actually two tables) that stores the data and the relational data for the tree.

The reasons I am saying this are the following:

  • To insert you call stree.addnode(parent, data), which will store all data in the tree object stree's fields Node = {} and Parent = []
  • you seem to know beforehand which element in your tree you want to access, as the search code is not given (if you use stree.getchildren(ID) for it, I have some bad news)
  • once you have processed a node, you trace it back using find(), which is a linear list search

By no means does that mean the implementation is clumsy for the data, it may even be the best, depending on what you are doing. However it does explain your memory allocation issues and gives hints on how to resolve them.


Keep Data as lookup table

One way to store the data is to keep the underlying lookup table. I would only do this if you know the ID of the first element you want to modify without searching for it. In that case you can make your structure more efficient in two steps.

First, initialise your arrays bigger than you expect to need to store the data. If the lookup table's capacity is exceeded, a new one is initialized that is X fields larger, and a deep copy of the old data is made. If you need to expand capacity once or twice (over all insertions) it might not be an issue, but in your case a deep copy is made for every insertion!

Second, I would change the internal structure and merge the two tables Node and Parent. The reason for this is that back-propagation in your code takes O(depth_from_root * n), where n is the number of nodes in your table. This is because find() will iterate over the entire table for each parent.

Instead you can implement something similar to

tbl = cell(n, 1);   % n bigger than the expected number of nodes ("table" would shadow MATLAB's table function)
end_pointer = 1;    % index of the first free slot

% A plain cell array is a value, so insert returns the updated table and pointer.
function [tbl, end_pointer, ID] = insert(tbl, end_pointer, data, parent_ID)
    if end_pointer > numel(tbl)
        % need more space: double the capacity so this happens rarely
        tbl = [tbl; cell(numel(tbl), 1)];
    end
    tbl{end_pointer} = struct('data', data, 'parent', parent_ID);
    ID = end_pointer;
    end_pointer = end_pointer + 1;
end

function content = get_value(tbl, ID)
    content = tbl{ID};   % curly braces return the stored struct, not a 1x1 cell
end

This instantly gives you access to the parent's ID without the need to find() it first, saving n iterations at each step, so the effort becomes O(depth). If you do not know your initial node, then you have to find() that one, which costs O(n).
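Back-propagation with this layout then reduces to following the stored parent IDs upwards. A minimal sketch, using a small hypothetical table with data/parent fields as in the code above and a placeholder update rule (adding the value to data):

```matlab
% Hypothetical lookup table: each entry is a struct with fields data and
% parent (parent 0 marks the root). Node 3 is a child of 2, which is a child of 1.
tbl = {struct('data', 0, 'parent', 0), ...
       struct('data', 0, 'parent', 1), ...
       struct('data', 0, 'parent', 2)};

% Back-propagate a value from a known node ID to the root by following the
% stored parent IDs: O(depth) steps and no find() over the whole table.
ID = 3;
value = 0.5;
while ID ~= 0
    tbl{ID}.data = tbl{ID}.data + value;   % placeholder update rule
    ID = tbl{ID}.parent;
end
```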

Note that this structure has no need for isleaf(), depth(), nnodes() or getchildren(). If you still need those, I need more insight into what you want to do with your data, as this strongly influences the proper structure.
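If getchildren() and isleaf() are still needed, one option (an assumption for illustration, not part of the answer above) is to keep a per-node list of child IDs next to the parent IDs and update it on every insertion, so neither query has to scan the whole Parent array:

```matlab
% Hypothetical extension of the flat layout: alongside Parent, keep a cell
% array of child-ID lists that is updated whenever a node is inserted.
Parent   = 0;       % node 1 is the root
Children = {[]};

% Insert three nodes: 2 and 3 under the root, 4 under node 2.
for new_parent = [1 1 2]
    ID = numel(Parent) + 1;
    Parent(ID)   = new_parent;
    Children{ID} = [];
    Children{new_parent}(end+1) = ID;    % only the parent's own list changes
end

kids_of_root = Children{1};              % getchildren without find()
leaf3 = isempty(Children{3});            % isleaf without scanning Parent
```

The cost is one extra small array per node; whether that pays off depends on how often the children of a node are queried compared with how often nodes are added.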


Tree Structure

This structure makes sense if you never know the first node's ID and thus always have to search for it.

The benefit is that the search for an arbitrary node works in O(depth), so searching is O(depth) instead of O(n), and back-propagating is O(depth^2) instead of O(depth + n). Note that depth can be anything from log(n) for a perfectly balanced tree (which may be possible depending on your data) to n for a degenerate tree that is just a linked list.

However, to suggest something proper I would require more insight, as every tree structure has its own niche. From what I can see so far, I'd suggest an unbalanced tree that is 'sorted' by the simple order given by a node's intended parent. This may be further optimized depending on

  • is it possible to define a total order on your data
  • how do you treat duplicate values (same data appearing twice)
  • what scale is your data (thousands, millions, ...)
  • is a lookup / search always paired with back-propagation
  • how long are the 'parent-child' chains in your data (or how balanced and deep will the tree be using this simple order)
  • is there always just one parent, or is the same element inserted twice with different parents

I'll happily provide example code for above tree, just leave me a comment.

EDIT: In your case an unbalanced tree (constructed in parallel with running the MCTS) seems to be the best option. The code below assumes that the data is split into state and score, and further that a state is unique. If it isn't, this will still work; however, there is a possible optimisation to increase MCTS performance.

classdef node < handle
    % A node for a tree in a MCTS
    properties
        state = {}; %some state of the search space that identifies the node
        score = 0;
        childs = cell(50,1);
        num_childs = 0;
    end
    methods
        function obj = node(state)
            % for a new node simulate a score using MC
            obj.score = simulate_from(state); % TODO implement simulation state -> finish
            obj.state = state;
        end
        function value = update(obj)
            % update this node using MC recursively
            if obj.num_childs == numel(obj.childs)
                % all slots are in use: double the capacity (vertical concat keeps a column)
                obj.childs = [obj.childs; cell(obj.num_childs,1)];
            end
            if obj.do_exploration() || obj.num_childs == 0
                % explore a potential state
                state_to_explore = obj.explore();

                % check if state has already been visited
                found_idx = 0;
                for idx = 1:obj.num_childs
                    if obj.childs{idx}.state_equals(state_to_explore)
                        found_idx = idx;
                        break;
                    end
                end

                % perform the corresponding action based on the search
                if found_idx == 0
                    % state has never been visited
                    % this action terminates the update recursion
                    % and creates a new leaf
                    obj.num_childs = obj.num_childs + 1;
                    obj.childs{obj.num_childs} = node(state_to_explore);
                    value = obj.childs{obj.num_childs}.calculate_value();
                    obj.update_score(value);
                else
                    % state has been visited at least once
                    value = obj.childs{found_idx}.update();
                    obj.update_score(value);
                end
            else
                % exploit what we know already
                best_idx = 1;
                for idx = 1:obj.num_childs
                    if obj.childs{idx}.score > obj.childs{best_idx}.score
                        best_idx = idx;
                    end
                end
                value = obj.childs{best_idx}.update();
                obj.update_score(value);
            end
            value = obj.calculate_value();
        end
        function state = explore(obj)
            %select a next state to explore, that may or may not be visited
            %TODO
        end
        function bool = do_exploration(obj)
            % decide if this node should be explored or exploited
            %TODO
        end
        function bool = state_equals(obj, test_state)
            % returns true if the nodes state is equal to test_state
            %TODO
        end
        function update_score(obj, value)
            % updates the score based on some value
            %TODO
        end
        function value = calculate_value(obj)
            % returns the value of this node to update previous nodes
            %TODO
        end
    end
end

A few comments on the code:

  • depending on the setup, obj.calculate_value() might not be needed, e.g. if it is some value that can be computed from the children's scores alone
  • if a state can have multiple parents, it makes sense to reuse the node object and account for this in the structure
  • as each node knows all its children, a subtree can easily be obtained by using any node as the root node
  • searching the tree (without any update) is a simple recursive greedy search
  • depending on the branching factor of your search, it might be worth visiting each possible child once (upon initialization of the node) and later doing randsample(obj.childs,1) for exploration, as this avoids copying / reallocating the child array
  • the parent property is encoded implicitly: the tree is updated recursively, passing value to the parent upon finishing the update for a node
  • The only time I reallocate memory is when a single node has more than 50 children, and I only reallocate for that individual node
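For the exploration step mentioned in the list above, note that randsample belongs to the Statistics and Machine Learning Toolbox and, applied to the whole preallocated obj.childs cell, could also return one of the empty slots. A minimal sketch using the core randi function on the filled part only; the struct stand-ins for child nodes are hypothetical:

```matlab
% Hypothetical node contents: a preallocated 50-slot cell, of which only
% num_childs slots are filled (structs stand in for child node objects).
childs = cell(50, 1);
num_childs = 3;
for k = 1:num_childs
    childs{k} = struct('score', k);
end

% Exploration: pick one existing child uniformly at random. randi only draws
% from 1..num_childs, so the empty preallocated slots can never be chosen.
pick   = randi(num_childs);
chosen = childs{pick};
```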

This should run a lot faster, as it just worries about whatever part of the tree is chosen and does not touch anything else.

  • Thank you for the response! My application is a type of a Monte Carlo tree search. I did not initially realise how much it would impact designing the tree, sorry about that! I updated the question to include some more details. I hope it helps you to focus your answer even better. – mikkola Dec 03 '15 at 09:58
  • Another followup: For the particular application and the way I traverse the tree, do you believe a lookup table implementation would actually be more appropriate? I had not considered it before. The tree search aspect, the visualizations and other documents I've read on MCTS lured me right into implementing it with a tree data structure as well. – mikkola Dec 03 '15 at 10:07
  • Am I right in assuming that your 'data' consists of some fancy state + a score value for that state? And further that the update woun't affect the state, but rather change the score? –  Dec 03 '15 at 11:10
  • Exactly. The data first of all consists of an identifier (typically integer) that links the node to something that is meaningful in the underlying optimization task, but which requires also knowing the *history*, i.e. the identifiers of the nodes through which this node was reached, to make sense overall. Furthermore, there is a score (floating point), and a visitation count (positive integer). Update changes the score and count only. – mikkola Dec 03 '15 at 12:04
  • To continue with the lookup table idea, an alternative might be to save this history for every node removing the need for the history. – mikkola Dec 03 '15 at 12:10
  • this mainly depends on if your underlying space is markov or not. If you somehow can fullfill the markov property, then it is usually wise to use this so that the entire history woun't be needed to describe things. –  Dec 03 '15 at 12:36
  • I agree, and I am working on a Markovian decision process optimization task. But the process is partially observable, so I do need to keep track of either 1) a probability distribution over the state or 2) a sequence of actions and observations (a history) from which I can reconstruct the pdf if necessary. The latter is much more convenient since the history is a sequence pairs of integers compared to a real-valued function in the former. But I think we are getting sidetracked :) – mikkola Dec 03 '15 at 12:40
  • @mikkola I've updated my answer. Further it would matter if you had access to the full model. This can be encoded in the datastructure to save runtime in the MCTS (as you could 'reuse' already visited states upon finding a new path to get there) –  Dec 03 '15 at 13:54

I know that this might sound stupid... but how about keeping the number of free nodes instead of the total number of nodes? The "is there room?" check then becomes a comparison against a constant (zero), which needs only a single property access.
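A minimal sketch of the free-slot counter outside the class; the variable names and the doubling growth rule are assumptions for illustration:

```matlab
% Hypothetical counters: num_free starts at the preallocated capacity, and
% the "is there room?" test becomes a comparison against the constant 0.
capacity  = 4;
Node      = cell(capacity, 1);
num_free  = capacity;
num_nodes = 0;

for k = 1:6
    if num_free == 0
        % out of room: grow by the current capacity and refill the counter
        Node = [Node; cell(capacity, 1)];   %#ok<AGROW>
        num_free = capacity;
        capacity = 2 * capacity;
    end
    num_nodes = num_nodes + 1;
    Node{num_nodes} = k;
    num_free = num_free - 1;
end
```

The invariant num_nodes + num_free == numel(Node) holds after every insertion, so either counter can be derived from the other when needed.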

One other voodoo improvement would be moving .maxSize next to .num_nodes and placing both before the .Node cell. That way their position in memory won't change relative to the beginning of the object as the .Node property grows (the voodoo here being me guessing the internal implementation of objects in MATLAB).

Later edit: When I profiled with .Node moved to the end of the property list, the bulk of the execution time was consumed by extending the .Node property, as expected (5.45 seconds, compared to 1.25 seconds for the comparison you mentioned).

  • Interesting! Comparison to a constant seems like a good idea, and I did observe a slight improvement. Not sure about the `.Node` property though - I did not observe an improvement switching the location of the property. My understanding is that cell array elements do not necessarily occupy contiguous blocks in memory, so the effect on performance is not easy to predict. – mikkola Nov 13 '15 at 18:33
  • @mikkola The thing is, the elements of a cell array are not in contiguous memory area, but their -- let's call them -- addresses must be. Conceptually a cell array is like an array of pointers, and that array itself grows in a contiguous memory area. By the way, do you clear your classes every time you run the benchmark? This usually worsens the performance of the newly instantiated objects (JIT stuff, plus the metaclass needs to be re-created again) –  Nov 13 '15 at 18:40
  • Thanks for the clarification regarding address pointers! I repeated the benchmarking adding `clear all` and `clear classes` each time. I still only saw nonsignificant improvement in the performance, although I agree your suggestions make sense. – mikkola Nov 15 '15 at 11:36
  • @mikkola I'm sorry that my suggestions don't bring significant results. Again, when the implementation is opaque all one can do is try several theories that might make sense (that's what I called voodoo programming). Eventually if one keeps pounding at the system it may find its sweet spot, or if not just wait until the next release that will improve the performance. –  Nov 16 '15 at 07:11

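You can try to allocate a number of elements proportional to the number of elements you have already filled (geometric growth); this is the standard implementation strategy for std::vector in C++:

obj.Node = [obj.Node; data; cell(ceil(q * obj.num_nodes),1)];

I don't remember exactly, but as far as I know GCC's libstdc++ doubles the capacity (q = 1), while MSVC grows it by a factor of 1.5 (q = 0.5). The ceil keeps the size argument an integer when q is fractional, and obj.Parent has to be grown the same way.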
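To see why proportional growth matters, the following sketch only counts reallocations and copied elements while inserting N nodes one at a time, comparing a fixed increment with q = 1. The values of N, the initial size and the increment are arbitrary assumptions:

```matlab
% Count how often storage is reallocated, and how many existing elements are
% copied in total, while inserting N nodes one at a time.
N    = 1e5;
init = 10;

% Fixed increment (the addnode strategy): grow by a constant number of slots.
inc = 10;
cap = init; realloc_fixed = 0; copied_fixed = 0;
for n = 1:N
    if n > cap
        copied_fixed  = copied_fixed + cap;   % old contents are copied
        cap           = cap + inc;
        realloc_fixed = realloc_fixed + 1;
    end
end

% Proportional growth with q = 1 (the capacity doubles).
q = 1;
cap = init; realloc_prop = 0; copied_prop = 0;
for n = 1:N
    if n > cap
        copied_prop  = copied_prop + cap;
        cap          = cap + ceil(q * cap);
        realloc_prop = realloc_prop + 1;
    end
end
```

With these numbers the fixed increment reallocates thousands of times, while doubling reallocates only a handful of times and copies fewer than 2N elements in total.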


This is a solution using Java. I don't like it very much, but it does its job. I implemented the example you extracted from Wikipedia.

import javax.swing.tree.DefaultMutableTreeNode

% Let's create our example tree
top = DefaultMutableTreeNode([11,21])
n1 = DefaultMutableTreeNode([7,10])
top.add(n1)
n2 = DefaultMutableTreeNode([2,4])
n1.add(n2)
n2 = DefaultMutableTreeNode([5,6])
n1.add(n2)
n3 = DefaultMutableTreeNode([2,3])
n2.add(n3)
n3 = DefaultMutableTreeNode([3,3])
n2.add(n3)
n1 = DefaultMutableTreeNode([4,8])
top.add(n1)
n2 = DefaultMutableTreeNode([1,2])
n1.add(n2)
n2 = DefaultMutableTreeNode([2,3])
n1.add(n2)
n2 = DefaultMutableTreeNode([2,3])
n1.add(n2)
n1 = DefaultMutableTreeNode([0,3])
top.add(n1)

% Element to look for, your implementation will be recursive
searching = [0 1 1];
idx = 1;
node(idx) = top;
for item = searching,
    % Java transposes the matrices, remember to transpose back when you are reading
    node(idx).getUserObject()'
    node(idx+1) = node(idx).getChildAt(item);
    idx = idx + 1;
end
node(idx).getUserObject()'

% We made a new test...
newdata = [0, 1]
newnode = DefaultMutableTreeNode(newdata)
% ...so we expand our tree at the last node we searched
node(idx).add(newnode)

% The change has to be propagated (this is where your recursion returns)
for it=length(node):-1:1,
    itnode=node(it);
    val = itnode.getUserObject()'
    newitemdata = val + newdata
    itnode.setUserObject(newitemdata)
end

% Let's see if the new values are correct
searching = [0 1 1 0];
idx = 1;
node(idx) = top;
for item = searching,
    node(idx).getUserObject()'
    node(idx+1) = node(idx).getChildAt(item);
    idx = idx + 1;
end
node(idx).getUserObject()'
NicolaSysnet
  • This was going to be my recommendation. Every time the vector runs out of room, it reserves 2x. Computers like powers of 2 :) – Nick Dec 02 '15 at 17:18
  • Thank you for the answer. Clarification: `searching = [0 1 1];` means that we want to find a node through the following strategy: first child of root, second child of that node, and finally ending at the second child of that one - correct? A quick test seems to indicate that the `UserObject` may be any Matlab supported type, hopefully this is true as well. – mikkola Dec 03 '15 at 12:24
  • The interpretation of the `searching` array is correct (arrays are 0-based in Java). There are limitations on the objects you can store in a Java array (like handles), so you have to try it with your actual objects. – NicolaSysnet Dec 03 '15 at 12:30