Random Node

Lessons Learned: binary search, randomization

Before diving in, I want to correct something in my previous post, Check Subtree: the second return statement of the second function (subTree) should be return true, not return false. Apologies for any confusion on this!

Question 9: Random Node: You are implementing a binary search tree class from scratch which, in addition to insert and find, has a method getRandomNode() that returns a random node from the tree. All nodes should be equally likely to be chosen. Design and implement an algorithm for getRandomNode, and explain how you would implement the rest of the methods.

One important thing to note here is what the question is actually asking. We don’t just have to implement a method called getRandomNode() - we have to implement a binary search tree class from scratch. This is a hint that we may need to modify the internal data structure in order to solve the problem.

One immediate solution that may come to mind is to copy all of the tree’s nodes into an array and return a random element of that array. This works, but it costs linear time and space on every call, and we were most likely given a tree rather than an array for a reason, so we should look for a better solution.
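
To make the cost of this first idea concrete, here is a minimal sketch of it. SimpleNode and ArraySampling are hypothetical names used only for this illustration (they are not the TreeNode class from later in the post), and the sketch assumes the tree is non-empty.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

// Hypothetical minimal node type for this sketch only.
class SimpleNode {
  int data;
  SimpleNode left;
  SimpleNode right;

  SimpleNode(int d) { data = d; }
}

class ArraySampling {
  // In-order traversal that appends every node to the list: O(N) time and space.
  static void collect(SimpleNode node, List<SimpleNode> out) {
    if (node == null) return;
    collect(node.left, out);
    out.add(node);
    collect(node.right, out);
  }

  // Copies the whole tree into a list, then picks a uniformly random index.
  // Assumes root is not null (nextInt(0) would throw on an empty list).
  static SimpleNode getRandomNode(SimpleNode root, Random random) {
    List<SimpleNode> nodes = new ArrayList<>();
    collect(root, nodes);
    return nodes.get(random.nextInt(nodes.size()));
  }

  public static void main(String[] args) {
    SimpleNode root = new SimpleNode(5);
    root.left = new SimpleNode(3);
    root.right = new SimpleNode(8);
    System.out.println(getRandomNode(root, new Random()).data);
  }
}
```

Every call walks the entire tree before choosing anything, which is the overhead the rest of this post tries to avoid.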

Another, more efficient way is to maintain an array alongside the tree, adding and removing nodes in the array as they are added to and removed from the tree. Although this is better, removing a node from the array is slow: you must first find the node in the array, which takes linear time.

Let’s step away from the array idea for a sec. What if we knew the depth of the tree? Since we’re building our tree from scratch, we can keep track of this value in a class variable. We could pick a random depth, then traverse left or right at random until we reach that depth. This wouldn’t actually work: the depth differs between the left and right subtrees, and different levels hold different numbers of nodes, so the nodes would not all be equally likely to be chosen.

Let’s think about this a bit more, and start from just the root of the tree. What is the likelihood that we return the root? Since we have N nodes (let’s just say), the root should be returned with probability 1/N. This is actually true for every node: each one must have a 1/N chance of being chosen.

So now we’ve solved for the root. But what about the left and right subtrees? Should there be a 50/50 chance that the node is on the left or the right? No. Even if the tree is balanced, the number of nodes can differ between left and right. The correct way to think about it is that the chance we go to the left subtree is the sum of the probabilities of choosing each node on the left side. In other words, LEFT_SIZE * 1/N.

Likewise, the odds of going to the right subtree are RIGHT_SIZE * 1/N.
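
As a quick sanity check on these probabilities, the short derivation below uses L and R as shorthand (my notation, not part of the original question) for LEFT_SIZE and RIGHT_SIZE, with N = 1 + L + R. It assumes that once we step into a subtree, we pick uniformly among that subtree's nodes by applying the same rule recursively.

```latex
% The three choices at the root cover every case:
P(\text{root}) + P(\text{left}) + P(\text{right})
  = \frac{1}{N} + \frac{L}{N} + \frac{R}{N}
  = \frac{1 + L + R}{N} = 1

% A node x in the left subtree is picked with probability 1/L
% once we are inside that subtree, so overall:
P(x) = \frac{L}{N} \cdot \frac{1}{L} = \frac{1}{N}

% The same argument with R gives 1/N for any node in the right subtree.
```

For example, with N = 7, L = 4 and R = 2, we return the root with probability 1/7, go left with probability 4/7, and go right with probability 2/7.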

This means that when we create our data structure, each node must know how many nodes are in its left subtree and how many are in its right subtree. Lucky for us, we’re making the tree from scratch! We can easily store a variable, say size, in each node, counting the nodes in the subtree rooted at that node (itself included). We increment size on inserts, and decrement on deletes.

The code below implements this solution:

import java.util.Random;

class TreeNode {
  private int data;
  public TreeNode left;
  public TreeNode right;
  private int size = 0;

  public TreeNode(int d) {
    data = d;
    size = 1;
  }

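  // Returns a node chosen uniformly at random from the subtree rooted here.
  public TreeNode getRandomNode() {
    int leftSize = left == null ? 0 : left.size();
    Random random = new Random();
    // index is uniform over [0, size): the first leftSize values select the
    // left subtree, the next value selects this node, the rest select the right.
    int index = random.nextInt(size);
    if(index < leftSize) {
      return left.getRandomNode();
    } else if(index == leftSize) {
      return this;
    } else {
      return right.getRandomNode();
    }
  }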

  public void insertInOrder(int d) {
    if(d <= data) {
      if(left == null) {
        left = new TreeNode(d);
      } else {
        left.insertInOrder(d);
      }
    } else {
      if(right == null) {
        right = new TreeNode(d);
      } else {
        right.insertInOrder(d);
      }
    }
    size++;
  }

  private int size() { return size; }
  private int data() { return data; }

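  // Standard BST lookup: returns the node holding d, or null if d is absent.
  public TreeNode find(int d) {
    if(d == data) {
      return this;
    } else if (d < data) {
      // Only recurse when the child exists; otherwise d is not in the tree.
      return (left != null) ? left.find(d) : null;
    } else {
      return (right != null) ? right.find(d) : null;
    }
  }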
}

We’re done!
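
If you want to convince yourself that the size-based approach really is uniform, one option is to sample many times and count the results. The sketch below is a compact, self-contained restatement of the idea under hypothetical names (SizedNode, RandomNodeDemo, sampleCounts); it also passes a single Random instance down instead of creating one per call, which is a small variation rather than the code from this post.

```java
import java.util.Random;

// Hypothetical compact version of the size-tracking node, so this compiles on its own.
class SizedNode {
  final int data;
  SizedNode left;
  SizedNode right;
  int size = 1; // number of nodes in the subtree rooted here, including this one

  SizedNode(int d) { data = d; }

  // BST insert that bumps size on every node along the insertion path.
  void insert(int d) {
    if (d <= data) {
      if (left == null) left = new SizedNode(d); else left.insert(d);
    } else {
      if (right == null) right = new SizedNode(d); else right.insert(d);
    }
    size++;
  }

  // Goes left with probability leftSize/size, returns this node with
  // probability 1/size, and goes right otherwise.
  SizedNode getRandomNode(Random random) {
    int leftSize = (left == null) ? 0 : left.size;
    int index = random.nextInt(size);
    if (index < leftSize) return left.getRandomNode(random);
    if (index == leftSize) return this;
    return right.getRandomNode(random);
  }
}

class RandomNodeDemo {
  // Draws `trials` random nodes and counts how often each value in 0..maxValue appears.
  static int[] sampleCounts(SizedNode root, int maxValue, int trials, Random random) {
    int[] counts = new int[maxValue + 1];
    for (int i = 0; i < trials; i++) {
      counts[root.getRandomNode(random).data]++;
    }
    return counts;
  }

  public static void main(String[] args) {
    // Deliberately lopsided tree: 1 node to the left of the root, 5 to the right.
    SizedNode root = new SizedNode(2);
    for (int v : new int[] {1, 5, 4, 7, 3, 6}) root.insert(v);
    int[] counts = sampleCounts(root, 7, 70_000, new Random());
    for (int v = 1; v <= 7; v++) {
      System.out.println(v + ": " + counts[v]);
    }
  }
}
```

With 70,000 draws over 7 nodes, each count should land close to 10,000 even though the tree is far from balanced.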

Similar questions: