Segment Trees

As some of you may know, a segment tree is a data structure that answers queries over intervals of an array, and handles updates to the array, in logarithmic time. Basically, a segment tree stores information about specific intervals of the array in tree form. The root covers the whole array, its left child covers the start index to the middle index, its right child covers middle index + 1 to the end index, and so on recursively. Each leaf node therefore covers a single element of the array.
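To make the node-to-interval picture concrete, here is a small sketch that walks the tree for an array of n elements and records the index range each node covers. The class name SegmentIntervals is hypothetical. It assumes the same midpoint split and the same 0-based array layout (children of node i at 2*i+1 and 2*i+2) that the code later in this post uses.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper: lists the interval [start, end] covered by each node,
// using mid = start + (end - start) / 2 and 0-based heap indexing.
class SegmentIntervals {
    static List<String> describe(int n) {
        List<String> out = new ArrayList<>();
        walk(0, n - 1, 0, out);
        return out;
    }

    private static void walk(int start, int end, int node, List<String> out) {
        out.add("node " + node + " covers [" + start + ", " + end + "]");
        if (start == end) {
            return; // leaf: a single array element
        }
        int mid = start + (end - start) / 2;
        walk(start, mid, 2 * node + 1, out);   // left half
        walk(mid + 1, end, 2 * node + 2, out); // right half
    }

    public static void main(String[] args) {
        describe(4).forEach(System.out::println);
    }
}
```

For n = 4 it lists node 0 covering [0, 3], nodes 1 and 2 covering [0, 1] and [2, 3], and the leaves 3 to 6 covering one index each.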

The data stored for each interval can be anything, such as the sum of its elements, the maximum element, or the minimum element. I have written the code in Java for the sum. A segment tree like this one, built over an array of n elements, uses O(n) storage (fewer than 4n array slots) and can be built in O(n) time; each range query and each point update takes O(log n) time. (The O(n log n) storage and O(log n + k) query bounds, with k the number of reported intervals, belong to a different structure with the same name: the computational-geometry segment tree, which stores a set of intervals and reports all intervals that contain a query point.)
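Why the storage is linear: the leaves are padded up to the smallest power of two L that is at least n, which is less than 2n, and a full binary tree with L leaves has 2L - 1 nodes, so fewer than 4n slots are needed. The sketch below uses a hypothetical helper, SegmentTreeSize.slots, that computes the same 2 * 2^height - 1 count as the constructor further down, but with integer arithmetic instead of Math.log.

```java
// Hypothetical helper: number of array slots for a segment tree over n elements,
// i.e. 2 * (smallest power of two >= n) - 1.
class SegmentTreeSize {
    static int slots(int n) {
        int leaves = 1;
        while (leaves < n) {
            leaves *= 2;          // smallest power of two >= n
        }
        return 2 * leaves - 1;    // nodes in a full binary tree with that many leaves
    }

    public static void main(String[] args) {
        for (int n : new int[] {1, 5, 8, 9, 1000}) {
            System.out.println("n = " + n + ": " + slots(n) + " slots, 4n = " + 4 * n);
        }
    }
}
```

For example, 9 elements need 31 slots and 1000 elements need 2047, both well under 4n.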

The segment tree is stored in an array, the way a binary heap is. With the root at index 0, as in the code below, the left child of node i is at 2*i+1 and the right child at 2*i+2 (with a 1-based root they would be 2*i and 2*i+1). To build the tree for the range-sum problem, I call the construction method recursively on the child nodes. A leaf node is the base case of the recursion: it simply stores its array element. Each parent then sets its value to the sum of its children's values, so the whole tree is filled in from the bottom up. All of the snippets below are members of a class named SegmentTree.


private int[] treeNode;  // the tree, stored as an array with the root at index 0
private int maxsize;     // number of slots allocated for treeNode
private int height;      // height of the tree

private final int STARTINDEX = 0;
private final int ENDINDEX;
private final int ROOT = 0;

public SegmentTree(int[] elements)
{
  height = (int) Math.ceil(Math.log(elements.length) / Math.log(2)); // height = ceil(log2(n)), so O(log n)
  maxsize = 2 * (int) Math.pow(2, height) - 1;  // nodes in a full binary tree of that height (fewer than 4n)
  treeNode = new int[maxsize];
  ENDINDEX = elements.length - 1;  // last valid index of the input array
  constructSegmentTree(elements, STARTINDEX, ENDINDEX, ROOT);  // build the tree over the whole array
}

private int getLeftChild(int pos){   // left child index in the 0-based array layout
   return 2 * pos + 1;
}

private int rightchild(int pos){     // right child index in the 0-based array layout
   return 2 * pos + 2;
}

private int constructSegmentTree(int[] elements, int startIndex, int endIndex, int current)
{
   if (startIndex == endIndex)  // base case: leaf node holding a single element
   {
      treeNode[current] = elements[startIndex];
      return treeNode[current];
   }
   int mid = startIndex + (endIndex - startIndex) / 2;
   // build both halves recursively and store the sum of the children in the current node
   treeNode[current] = constructSegmentTree(elements, startIndex, mid, getLeftChild(current))
         + constructSegmentTree(elements, mid + 1, endIndex, rightchild(current));
   return treeNode[current];
}
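To see what construction actually produces, here is a standalone sketch that repeats the same recursion on a local array and returns it so the layout can be printed. The class BuildDemo is hypothetical, and it assumes the array is sized like the constructor above, 2 * (smallest power of two >= n) - 1.

```java
import java.util.Arrays;

// Hypothetical standalone sketch: builds a sum segment tree with the same
// recursion and 0-based layout, and returns the raw array for inspection.
class BuildDemo {
    static int[] build(int[] elements) {
        int leaves = 1;
        while (leaves < elements.length) {
            leaves *= 2;
        }
        int[] tree = new int[2 * leaves - 1];   // same size as the constructor's maxsize
        construct(elements, tree, 0, elements.length - 1, 0);
        return tree;
    }

    private static int construct(int[] elements, int[] tree, int start, int end, int current) {
        if (start == end) {
            tree[current] = elements[start];     // leaf holds one element
            return tree[current];
        }
        int mid = start + (end - start) / 2;
        tree[current] = construct(elements, tree, start, mid, 2 * current + 1)
                + construct(elements, tree, mid + 1, end, 2 * current + 2);
        return tree[current];                     // node value = sum of its two children
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(build(new int[] {1, 2, 3, 4, 5})));
    }
}
```

For {1, 2, 3, 4, 5} it prints [15, 6, 9, 3, 3, 4, 5, 1, 2, 0, 0, 0, 0, 0, 0]: index 0 holds the total, indices 1 and 2 hold the sums of [0, 2] and [3, 4], and the trailing zeros are slots that no node uses because 5 is not a power of two.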

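Here’s the method that answers a range-sum query. The node current covers the elements startIndex..endIndex, and queryStart..queryEnd is the requested range.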

private int getSum(int startIndex, int endIndex, int queryStart, int queryEnd, int current)
{
   if (queryStart <= startIndex && queryEnd >= endIndex)  // node's segment lies entirely inside the query range
      return treeNode[current];

   if (endIndex < queryStart || startIndex > queryEnd)  // node's segment is disjoint from the query range
      return 0;

   // partial overlap: query both children and add the results
   int mid = startIndex + (endIndex - startIndex) / 2;
   return getSum(startIndex, mid, queryStart, queryEnd, getLeftChild(current))
         + getSum(mid + 1, endIndex, queryStart, queryEnd, rightchild(current));
}

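public int query(int queryStart, int queryEnd)
{
   // reject ranges outside the array or with start after end
   // (note: -1 can also be a real sum if the array contains negative numbers)
   if (queryStart < 0 || queryEnd > ENDINDEX || queryStart > queryEnd)
      return -1;

   return getSum(STARTINDEX, ENDINDEX, queryStart, queryEnd, ROOT);
}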
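A query is answered by adding up a few nodes whose segments lie entirely inside the query range. The hypothetical QueryDecomposition sketch below follows the same three cases as getSum (inside, disjoint, partial overlap), but records the fully covered segments instead of reading stored values, then sums the elements of each segment directly to mimic what treeNode would hold.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch: records which fully covered segments getSum would add up.
// Case order matches getSum: fully inside, then disjoint, then split at mid.
class QueryDecomposition {
    static List<int[]> coveredIntervals(int n, int queryStart, int queryEnd) {
        List<int[]> covered = new ArrayList<>();
        visit(0, n - 1, queryStart, queryEnd, covered);
        return covered;
    }

    private static void visit(int start, int end, int qs, int qe, List<int[]> covered) {
        if (qs <= start && qe >= end) {
            covered.add(new int[] {start, end});  // whole node used, no further descent
            return;
        }
        if (end < qs || start > qe) {
            return;                               // disjoint: contributes 0
        }
        int mid = start + (end - start) / 2;
        visit(start, mid, qs, qe, covered);
        visit(mid + 1, end, qs, qe, covered);
    }

    public static void main(String[] args) {
        int[] elements = {1, 2, 3, 4, 5, 6, 7, 8, 9};
        int total = 0;
        for (int[] range : coveredIntervals(elements.length, 1, 4)) {
            int nodeSum = 0;
            for (int i = range[0]; i <= range[1]; i++) {
                nodeSum += elements[i];           // the value treeNode would store for this node
            }
            System.out.println("[" + range[0] + ", " + range[1] + "] -> " + nodeSum);
            total += nodeSum;
        }
        System.out.println("total = " + total);
    }
}
```

For the test run's query (1, 4) on nine elements it prints [1, 1] -> 2, [2, 2] -> 3, [3, 4] -> 9 and total = 14. Only three stored values are combined, and at most two partially overlapping nodes per level are expanded, which is where the O(log n) query bound comes from.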

Here’s the code for updating a single element (a point update):

private void updateTree(int startIndex, int endIndex, int updatePos, int update, int current)
{
   if (updatePos < startIndex || updatePos > endIndex)  // updatePos is outside this node's segment
      return;
   treeNode[current] = treeNode[current] + update;  // segment contains updatePos: add the difference, then recurse into the children
   if (startIndex != endIndex)
   {
      int mid = startIndex + (endIndex - startIndex) / 2;
      updateTree(startIndex, mid, updatePos, update, getLeftChild(current));
      updateTree(mid + 1, endIndex, updatePos, update, rightchild(current));
   }
}

// Sets elements[updatePos] to update. It first computes the difference from the old value,
// then adds that difference to every node whose segment contains updatePos, starting at the root.
// elements must be the same array the tree was built from, since the old value is read from it.
public void update(int update, int updatePos, int[] elements)
{
   int updatediff = update - elements[updatePos];
   elements[updatePos] = update;  // update the array itself first
   updateTree(STARTINDEX, ENDINDEX, updatePos, updatediff, ROOT);
}
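An alternative design, sketched below as a hypothetical class RecomputingSumTree (not the post's code), writes the new value at the leaf and recomputes each ancestor from its two children on the way back up. The tree then does not need the caller's elements array to find the old value or the difference.

```java
// Hypothetical variant: point update by "set leaf, recompute parents"
// instead of adding a difference to every node on the path.
class RecomputingSumTree {
    private final int[] tree;
    private final int n;

    RecomputingSumTree(int[] elements) {
        n = elements.length;
        int leaves = 1;
        while (leaves < n) {
            leaves *= 2;
        }
        tree = new int[2 * leaves - 1];
        build(elements, 0, n - 1, 0);
    }

    private int build(int[] elements, int start, int end, int node) {
        if (start == end) {
            tree[node] = elements[start];
            return tree[node];
        }
        int mid = start + (end - start) / 2;
        tree[node] = build(elements, start, mid, 2 * node + 1)
                + build(elements, mid + 1, end, 2 * node + 2);
        return tree[node];
    }

    // Sets element pos to value (pos is assumed to be within 0..n-1).
    void set(int pos, int value) {
        set(0, n - 1, pos, value, 0);
    }

    private void set(int start, int end, int pos, int value, int node) {
        if (start == end) {
            tree[node] = value;                    // the leaf for pos
            return;
        }
        int mid = start + (end - start) / 2;
        if (pos <= mid) {
            set(start, mid, pos, value, 2 * node + 1);
        } else {
            set(mid + 1, end, pos, value, 2 * node + 2);
        }
        tree[node] = tree[2 * node + 1] + tree[2 * node + 2];  // recompute from children
    }

    int sum(int qs, int qe) {
        return sum(0, n - 1, qs, qe, 0);
    }

    private int sum(int start, int end, int qs, int qe, int node) {
        if (qs <= start && qe >= end) {
            return tree[node];
        }
        if (end < qs || start > qe) {
            return 0;
        }
        int mid = start + (end - start) / 2;
        return sum(start, mid, qs, qe, 2 * node + 1) + sum(mid + 1, end, qs, qe, 2 * node + 2);
    }

    public static void main(String[] args) {
        RecomputingSumTree t = new RecomputingSumTree(new int[] {1, 2, 3, 4, 5, 6, 7, 8, 9});
        System.out.println(t.sum(1, 4));
        t.set(1, 3);
        System.out.println(t.sum(1, 4));
    }
}
```

With the same data as the test run it prints 14 and then 15. Both versions touch O(log n) nodes per update; the choice is mostly about whether the tree or the caller owns the current values.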

This is the test run (the main method also goes inside the SegmentTree class):

public static void main(String[] args)
{
    int[] elements = {1, 2, 3, 4, 5, 6, 7, 8, 9};
    SegmentTree segmentTree = new SegmentTree(elements);  // build the segment tree
    int sum = segmentTree.query(1, 4);  // sum of indices 1..4 (0-based, inclusive): 2 + 3 + 4 + 5

    System.out.println("the sum is " + sum);
    segmentTree.update(3, 1, elements);  // set elements[1] from 2 to 3
    sum = segmentTree.query(1, 4);  // query again: 3 + 3 + 4 + 5
    System.out.println("the sum is " + sum);
}

And the output is

the sum is 14
the sum is 15

A segment tree stores precomputed values for O(n) intervals of the array, and any query range can be split into O(log n) of those intervals, so each query and each point update takes logarithmic time. That makes segment trees very helpful for range min, range max, or range sum problems with a large number of queries.
However, if the problem asks for updates to a whole range of elements at once, applying them one element at a time becomes too slow; for that, lazy propagation comes in handy, which I will discuss in my next blog.
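For the range-min problems mentioned above, only the combine step and the value returned for a disjoint node change. Here is a minimal sketch under that assumption, using a hypothetical class MinSegmentTree: each node stores the minimum of its segment, and a disjoint node returns Integer.MAX_VALUE, which never affects a min, in place of the 0 that getSum returns.

```java
// Hypothetical variant: same recursive layout, but nodes store the minimum of their segment.
class MinSegmentTree {
    private final int[] tree;
    private final int n;

    MinSegmentTree(int[] elements) {
        n = elements.length;
        int leaves = 1;
        while (leaves < n) {
            leaves *= 2;
        }
        tree = new int[2 * leaves - 1];
        build(elements, 0, n - 1, 0);
    }

    private int build(int[] elements, int start, int end, int node) {
        if (start == end) {
            tree[node] = elements[start];
            return tree[node];
        }
        int mid = start + (end - start) / 2;
        tree[node] = Math.min(build(elements, start, mid, 2 * node + 1),
                build(elements, mid + 1, end, 2 * node + 2));
        return tree[node];
    }

    int min(int queryStart, int queryEnd) {
        return min(0, n - 1, queryStart, queryEnd, 0);
    }

    private int min(int start, int end, int qs, int qe, int node) {
        if (qs <= start && qe >= end) {
            return tree[node];               // node lies inside the query range
        }
        if (end < qs || start > qe) {
            return Integer.MAX_VALUE;        // disjoint: neutral element for min
        }
        int mid = start + (end - start) / 2;
        return Math.min(min(start, mid, qs, qe, 2 * node + 1),
                min(mid + 1, end, qs, qe, 2 * node + 2));
    }

    public static void main(String[] args) {
        MinSegmentTree t = new MinSegmentTree(new int[] {5, 2, 8, 6, 3, 7});
        System.out.println(t.min(0, 5));
        System.out.println(t.min(2, 5));
    }
}
```

This prints 2 and then 3. One caveat: the difference-based update from this post does not carry over to min, because a node's minimum cannot be corrected by adding a difference; a point update in a min tree would instead recompute each ancestor from its children.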
Here is the link to 60 problems on segment trees; try them to get a hold of the topic:
problems on segment trees

By the way, this is my first blog, so feel free to leave any suggestions in the comments.

Keep learning!

Parth Panchal