LintCode/Remove Node In Binary Search Tree

Problem Summary

Given the root of a binary search tree in which every node has a unique value, remove the node with the given value V. If no node in the tree has that value, do nothing. The tree must still be a binary search tree after the removal.
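To make the requirement "still a binary search tree" concrete, here is a minimal sketch of a checker for that invariant. The names isBST and isBSTInRange are hypothetical helpers introduced for illustration, not part of the LintCode interface; TreeNode mirrors the definition that comes with the problem.

```cpp
#include <cassert>
#include <climits>

// Mirrors the TreeNode definition supplied with the problem.
class TreeNode {
public:
    int val;
    TreeNode *left, *right;
    TreeNode(int val) : val(val), left(nullptr), right(nullptr) {}
};

// Hypothetical helper: true if every value in the subtree lies strictly
// between lo and hi and the BST ordering holds at every node.
// long long bounds avoid trouble with values equal to INT_MIN or INT_MAX.
bool isBSTInRange(const TreeNode *node, long long lo, long long hi) {
    assert(lo < hi);  // holds on every call that starts from isBST
    if (node == nullptr) return true;
    if (node->val <= lo || node->val >= hi) return false;
    return isBSTInRange(node->left, lo, node->val) &&
           isBSTInRange(node->right, node->val, hi);
}

// Hypothetical entry point: checks the whole tree.
bool isBST(const TreeNode *root) {
    return isBSTInRange(root, LLONG_MIN, LLONG_MAX);
}
```

A check like this is handy for testing any removal routine: the tree should pass it both before and after the call.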

Solution

Let us solve this recursively.

For the current root:

  1. If root->val < V, then V can only be in the right subtree of root, if it exists at all. So we set root->right to the result of removeNode(root->right,V) and return root.

  2. Similarly, if root->val > V, we set root->left to the result of removeNode(root->left,V) and return root.

  3. If root->val == V, we need to remove root.
       (1) If root->left or root->right is NULL, we replace root with the other child (which may also be NULL) and return it as the new root.
       (2) If neither root->left nor root->right is NULL, we find the precursor (predecessor) of root in the inorder traversal of the tree, i.e. the node with the maximum value in the left subtree of root, and copy its value into root. The precursor node itself must then be removed from the left subtree so that the tree remains a BST.
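Case (2) depends on locating the maximum of the left subtree. A minimal iterative sketch of that step follows. The names maxNode and predecessorValue are hypothetical and introduced only for illustration; TreeNode mirrors the problem's definition.

```cpp
#include <cassert>

// Mirrors the TreeNode definition supplied with the problem.
class TreeNode {
public:
    int val;
    TreeNode *left, *right;
    TreeNode(int val) : val(val), left(nullptr), right(nullptr) {}
};

// Hypothetical helper: the rightmost node of a non-empty subtree,
// which is the node holding its maximum value.
TreeNode *maxNode(TreeNode *cur) {
    assert(cur != nullptr);
    while (cur->right != nullptr)
        cur = cur->right;
    return cur;
}

// Value of the inorder predecessor of a node that has a left child.
int predecessorValue(TreeNode *node) {
    assert(node != nullptr && node->left != nullptr);
    return maxNode(node->left)->val;
}
```

For a tree with root 5, left child 3 (whose children are 2 and 4) and right child 8, predecessorValue on the root gives 4. The rightmost node never has a right child, which is why removing it afterwards always falls under case (1).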

About the implementation of (2):

At first I wrote a function "get_precursor" for this. It finds the precursor and removes it from the tree, but it needs a little cooperation from the calling function: when the precursor is root->left itself, removeNode has to unlink it (see the check pre == root->left->val in Code 1). Then I realized this is not necessary. So I wrote a function "get_max", which also finds the precursor but does not change the tree. With it, I just call get_max(root->left) and then removeNode(root->left,precursor) to remove the precursor from the left subtree (see Code 2). The precursor has no right child, so that recursive call always ends in case (1). The new method is easier to code and makes the code more robust and explicit. How amazing recursion is!

Code 1

/**
 * Definition of TreeNode:
 * class TreeNode {
 * public:
 *     int val;
 *     TreeNode *left, *right;
 *     TreeNode(int val) {
 *         this->val = val;
 *         this->left = this->right = NULL;
 *     }
 * };
 */
class Solution {
public:
    // Returns the maximum value in the subtree rooted at cur and unlinks
    // the node holding it from its parent. If the maximum is the node the
    // initial call started from (parent == NULL), nothing is unlinked and
    // the caller has to do it.
    int get_precursor(TreeNode *cur, TreeNode *parent)
    {
        if (cur->right != NULL)
            return get_precursor(cur->right, cur);
        if (parent != NULL)
            parent->right = cur->left;
        return cur->val;
    }

    /*
     * @param root: The root of the binary search tree.
     * @param value: Remove the node with given value.
     * @return: The root of the binary search tree after removal.
     */
    TreeNode * removeNode(TreeNode * root, int value) {
        if (root == NULL)
            return NULL;
        if (root->val == value)
        {
            // Case (1): at most one child, so the other child replaces root.
            if (root->left == NULL)
                return root->right;
            if (root->right == NULL)
                return root->left;
            // Case (2): copy the precursor's value into root.
            int pre = get_precursor(root->left, NULL);
            // The precursor is root->left itself, so unlink it here.
            if (pre == root->left->val)
                root->left = root->left->left;
            root->val = pre;
        }
        else
        {
            if (root->val < value)
                root->right = removeNode(root->right, value);
            else
                root->left = removeNode(root->left, value);
        }
        return root;
    }
};

Code 2

class Solution {
public:
    // Returns the maximum value in the non-empty subtree rooted at cur.
    // The tree is not modified.
    int get_max(TreeNode *cur)
    {
        if (cur->right != NULL)
            return get_max(cur->right);
        return cur->val;
    }

    /*
     * @param root: The root of the binary search tree.
     * @param value: Remove the node with given value.
     * @return: The root of the binary search tree after removal.
     */
    TreeNode * removeNode(TreeNode * root, int value) {
        if (root == NULL)
            return NULL;
        if (root->val == value)
        {
            // Case (1): at most one child, so the other child replaces root.
            if (root->left == NULL)
                return root->right;
            if (root->right == NULL)
                return root->left;
            // Case (2): copy the precursor's value into root, then remove
            // the precursor from the left subtree recursively.
            int pre = get_max(root->left);
            root->val = pre;
            root->left = removeNode(root->left, pre);
        }
        else
        {
            if (root->val < value)
                root->right = removeNode(root->right, value);
            else
                root->left = removeNode(root->left, value);
        }
        return root;
    }
};
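
Both versions above unlink the removed node without freeing it. That is harmless on an online judge, but it leaks memory in a long-running program. The following is a sketch of the Code 2 approach that also deletes the unlinked node, under the assumption that every node was allocated with new. The names removeAndFree, getMax and inorder are hypothetical and introduced only for illustration.

```cpp
#include <cassert>
#include <vector>

// Mirrors the TreeNode definition supplied with the problem.
class TreeNode {
public:
    int val;
    TreeNode *left, *right;
    TreeNode(int val) : val(val), left(nullptr), right(nullptr) {}
};

// Largest value in a non-empty subtree (same role as get_max in Code 2).
int getMax(const TreeNode *cur) {
    assert(cur != nullptr);
    while (cur->right != nullptr)
        cur = cur->right;
    return cur->val;
}

// Hypothetical variant of removeNode that frees the unlinked node.
// Assumption: every node in the tree was allocated with new.
TreeNode *removeAndFree(TreeNode *root, int value) {
    if (root == nullptr) return nullptr;
    if (root->val < value) {
        root->right = removeAndFree(root->right, value);
    } else if (root->val > value) {
        root->left = removeAndFree(root->left, value);
    } else if (root->left == nullptr || root->right == nullptr) {
        // At most one child: that child (possibly nullptr) replaces root.
        TreeNode *child = (root->left != nullptr) ? root->left : root->right;
        delete root;
        return child;
    } else {
        // Two children: copy the predecessor's value up, then remove the
        // predecessor from the left subtree, where it has no right child.
        int pre = getMax(root->left);
        root->val = pre;
        root->left = removeAndFree(root->left, pre);
    }
    return root;
}

// Inorder traversal, useful for checking that the values stay sorted.
void inorder(const TreeNode *node, std::vector<int> &out) {
    if (node == nullptr) return;
    inorder(node->left, out);
    out.push_back(node->val);
    inorder(node->right, out);
}
```

Only one node is ever deleted per call: in the two-children case the root node stays in place with a new value, and the node that actually disappears is the predecessor, handled by the recursive call.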