
I have been trying to follow the algorithm provided here to find the k'th smallest value in a binary search tree. However, I'm not clear about how to search the right subtree (is this where the error in my code is?).

For example, for the following tree: 9, 3, 6, 5, 8, 13, 14.

   9
  / \
 /   \
3    13
 \     \
  6    14
 / \
5   8

The output I get is:

k = 0, returned 3    
k = 1, returned 6     // should return 5   
k = 2, returned -1    // should return 6
k = 3, returned -1    // should return 8
k = 4, returned 9     
k = 5, returned -1    // should return 13
k = 6, returned -1    // should return 14

In another example - 4, 3, 2, 5, 6 - the output is also partially correct and fails to locate all nodes in the right subtree.

    4
   / \
  3   5
 /     \
2       6

k = 0, returned 2
k = 1, returned 3
k = 2, returned 4
k = 3, returned -1    // should return 5
k = 4, returned -1    // should return 6

Can someone please explain how to locate the k'th smallest value in the right subtrees?

My code:

int length(Tree t) {
    if (t == NULL) {
        return 0;
    } else {
        return 1 + length(t->left) + length(t->right); 
    }
}

int findSmallest(Tree t, int k) {

    if (t == NULL) {
        return -1; 
    }

    if (k == length(t->left)) {
        return t->value;
    }

    if (k < length(t->left)) {
        return findSmallest(t->left, k); 
    }

    if (k > length(t->left)) {
        return findSmallest(t->right, (k - length(t->left)));  
    }

    return 0; 

}
J Szum

1 Answer


In short, you get the correct result by changing a single line in the function findSmallest:

  • original code:
    return findSmallest(t->right, (k - length(t->left))); 
    
  • corrected code:
    return findSmallest(t->right, (k - length(t->left) - 1));
    

In each call of findSmallest, the parameter k is the number of smaller nodes that still have to be skipped before reaching the answer (so k = 0 asks for the smallest node of the current tree).

We pass (k - length(t->left) - 1) because, when we move to the right, we have already skipped length(t->left) smaller nodes in the sub-tree of t->left plus one more smaller node, which is t itself. Only the remaining count has to be found inside the tree whose root node is t->right.

Note that length(t) simply returns the number of nodes in the tree whose root node is t. It does not compare values; it counts every node, whether smaller or larger than the root.
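
To watch k shrink on the way down, here is a minimal sketch that prints every step of the search on the second example tree from the question. The Node layout is an assumption (the question does not show its definition), and findSmallestTraced and lookupInSecondExample are names I made up for this illustration.

```c
#include <assert.h>
#include <stddef.h>
#include <stdio.h>

/* Assumed node layout; the question does not show its own definition. */
typedef struct node {
    struct node *left;
    struct node *right;
    int value;
} Node;

/* Number of nodes in the tree rooted at t. */
int length(const Node *t) {
    return t == NULL ? 0 : 1 + length(t->left) + length(t->right);
}

/* k is 0-based: k == 0 asks for the smallest value under t. */
int findSmallestTraced(const Node *t, int k) {
    if (t == NULL) {
        return -1; /* k was larger than the number of nodes */
    }
    /* Precondition of this sketch: callers pass k >= 0. With the
       "- 1" fix the recursion itself never makes k negative. */
    assert(k >= 0);

    int leftCount = length(t->left);
    printf("at %d: k = %d, left sub-tree holds %d\n", t->value, k, leftCount);

    if (k == leftCount) {
        return t->value;
    }
    if (k < leftCount) {
        return findSmallestTraced(t->left, k);
    }
    /* Skip the leftCount smaller nodes and t itself. */
    return findSmallestTraced(t->right, k - leftCount - 1);
}

/* Builds the question's second tree (4, 3, 2, 5, 6) and looks up index k. */
int lookupInSecondExample(int k) {
    Node n2 = { NULL, NULL, 2 };
    Node n6 = { NULL, NULL, 6 };
    Node n3 = { &n2, NULL, 3 };
    Node n5 = { NULL, &n6, 5 };
    Node n4 = { &n3, &n5, 4 };
    return findSmallestTraced(&n4, k);
}
```

Calling lookupInSecondExample(3) visits node 4 with k = 3 (two nodes on its left), then node 5 with k = 3 - 2 - 1 = 0, and returns 5. Without the "- 1" the second step would arrive with k = 1 and run off the tree.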


The following describes what's behind it step by step :

  1. Declare the data structure Tree
    Since the question doesn't show the whole code, I assume the data structure looks like this:

    struct tree;
    typedef struct tree {
        struct tree *left;
        struct tree *right;
        int value;
    } Tree;
    

    Each node contains its value and pointers to its two children.
    Every value in the left sub-tree is smaller than the parent's value.
    Every value in the right sub-tree is larger than the parent's value.
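
Since findSmallest only works when this ordering holds, it can be worth checking it. The following is a rough sketch, not part of the question's code: isSearchTree and checkQuestionTree are hypothetical names, and I assume equal values would go to the right, as in the insert function used later in this answer.

```c
#include <assert.h>
#include <stddef.h>

typedef struct tree {
    struct tree *left;
    struct tree *right;
    int value;
} Tree;

/* lo and hi are optional bounds (NULL means unbounded).
   Every value in the tree must satisfy *lo <= value < *hi. */
int isSearchTree(const Tree *t, const int *lo, const int *hi) {
    if (t == NULL) {
        return 1; /* an empty tree is trivially ordered */
    }
    if (lo != NULL && t->value < *lo) {
        return 0;
    }
    if (hi != NULL && t->value >= *hi) {
        return 0;
    }
    /* Left values must stay below t->value, right values at or above it. */
    return isSearchTree(t->left, lo, &t->value)
        && isSearchTree(t->right, &t->value, hi);
}

/* Usage on the question's first tree (9, 3, 6, 5, 8, 13, 14). */
void checkQuestionTree(void) {
    Tree n5  = { NULL, NULL, 5 };
    Tree n8  = { NULL, NULL, 8 };
    Tree n6  = { &n5, &n8, 6 };
    Tree n3  = { NULL, &n6, 3 };
    Tree n14 = { NULL, NULL, 14 };
    Tree n13 = { NULL, &n14, 13 };
    Tree n9  = { &n3, &n13, 9 };

    assert(isSearchTree(&n9, NULL, NULL));

    /* 10 inside the left sub-tree of 9 breaks the ordering. */
    n8.value = 10;
    assert(!isSearchTree(&n9, NULL, NULL));
}
```

Passing the bounds down is what catches a misplaced value deep in a sub-tree; comparing each node only with its direct children would miss the 10 in the example above.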

  2. Create the tree structure
    For simplicity, I just wrote a messy, ugly function to insert a new node into the binary tree.

    #include <stdlib.h>

    void addnode(Tree *parent, int value)
    {
        Tree *node = malloc(sizeof(*node));
    
        if (node == NULL) { /* out of memory: leave the tree unchanged */
            return;
        }
        node->left = NULL;
        node->right = NULL;
        node->value = value;
    
    /* Walk down (looping via goto) until the new node can be the direct child of the parent node */
    check:
        if (value < parent->value) { /* new node belongs to left sub-tree */
            if (parent->left) { /* sub-tree is not null, so dig into it */
                parent = parent->left;
                goto check;
            }
            parent->left = node;
        } else {
            if (parent->right) { /* sub-tree is not null, so dig into it */
                parent = parent->right;
                goto check;
            }
            parent->right = node;
        }
    }
    

    Note that:

    • Any node whose value is smaller than the root node's ends up in the left sub-tree of the root node.
    • Any node whose value is larger than the root node's ends up in the right sub-tree of the root node.

    In the main function, we just declare a root node and add the other nodes in order:

    int main()
    {
        Tree root = {
            .left = NULL,
            .right = NULL,
            .value = 9,
        };
    
        addnode(&root, 3);
        addnode(&root, 6);
        addnode(&root, 5);
        addnode(&root, 8);
        addnode(&root, 13);
        addnode(&root, 14);
    /* ... */
        return 0;
    }
    
  3. Provide the findSmallest function
    Its helper, the length function, returns the number of nodes (including the root node itself) that belong to the given tree:

    int length(Tree *t) {
        if (t == NULL) {
            return 0;
        } else {
            return 1 + length(t->left) + length(t->right); 
        }
    }
    

    Here the parameter t is declared as a pointer (Tree *), so we can use -> to access its members.
    And the final part, the findSmallest function:

    int findSmallest(Tree *t, int k) {
        if (t == NULL) { /* reached only when k is out of range */
            return -1; 
        }
    
        if (k == length(t->left)) {
            return t->value;
        }
    
        if (k < length(t->left)) {
            return findSmallest(t->left, k); 
        }
    
        if (k > length(t->left)) {
            return findSmallest(t->right, (k - length(t->left) - 1));  
        }
    
        return 0; 
    }
    

    The parameter t is a pointer here as well.
    length(t->left) is the number of nodes in the tree rooted at t whose value is less than t->value, because:

    • The function length(root) returns the number of nodes in the whole tree whose root node is root.
    • In the tree of this question, every node in a left sub-tree is smaller than that sub-tree's parent.
    • In the tree of this question, every node in a right sub-tree is larger than that sub-tree's parent.

    So if length(t->left) is 0, t is the smallest node in the tree whose root node is t.

    The test call and its output:

    int main() {
    /* ... */
        for (int k = 0; k < 6; k++) {
            int ret = findSmallest(&root, k);
            printf("k = %d, returned %d\n", k, ret);
        }
        return 0;
    }
    
    k = 0, returned 3
    k = 1, returned 5
    k = 2, returned 6
    k = 3, returned 8
    k = 4, returned 9
    k = 5, returned 13
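
A side note on cost: findSmallest calls length up to three times per level, and every call walks a whole sub-tree. If that matters, one alternative (a different technique from the code above, not a fix to it) is an in-order traversal that counts k down to zero. A rough sketch, where visitInOrder, findSmallestInOrder and checkFirstExample are hypothetical names of mine:

```c
#include <assert.h>
#include <stddef.h>

typedef struct tree {
    struct tree *left;
    struct tree *right;
    int value;
} Tree;

/* Visits nodes in ascending order. *remaining is how many nodes still
   have to be skipped. Returns 1 once the wanted node has been stored
   in *out, 0 if the tree ran out first. */
int visitInOrder(const Tree *t, int *remaining, int *out) {
    if (t == NULL) {
        return 0;
    }
    if (visitInOrder(t->left, remaining, out)) {
        return 1; /* already found in the left sub-tree */
    }
    if (*remaining == 0) {
        *out = t->value;
        return 1;
    }
    (*remaining)--; /* skip t itself and continue on the right */
    return visitInOrder(t->right, remaining, out);
}

/* Same contract as findSmallest: 0-based k, -1 when k is out of range. */
int findSmallestInOrder(const Tree *t, int k) {
    int out = -1;
    int remaining = k;

    if (k < 0) {
        return -1;
    }
    visitInOrder(t, &remaining, &out);
    return out;
}

/* Usage on the question's first tree (9, 3, 6, 5, 8, 13, 14). */
void checkFirstExample(void) {
    Tree n5  = { NULL, NULL, 5 };
    Tree n8  = { NULL, NULL, 8 };
    Tree n6  = { &n5, &n8, 6 };
    Tree n3  = { NULL, &n6, 3 };
    Tree n14 = { NULL, NULL, 14 };
    Tree n13 = { NULL, &n14, 13 };
    Tree n9  = { &n3, &n13, 9 };
    const int expected[] = { 3, 5, 6, 8, 9, 13, 14 };

    for (int k = 0; k < 7; k++) {
        assert(findSmallestInOrder(&n9, k) == expected[k]);
    }
    assert(findSmallestInOrder(&n9, 7) == -1);
}
```

This version stops as soon as the k'th node is reached and never counts a sub-tree, at the price of not being able to jump straight into the right sub-tree the way the length-based version does.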
    
imjlfish