Monday, November 03, 2008

python interval tree

EDIT: added a couple points inline.

I'm obsessed with trees lately -- of the CS variety, not the plant variety. Although we are studying poplar, so I'll be using trees to study trees.
I'd tried a couple times to implement an interval tree from scratch following the Wikipedia entry.
Today I more or less did that in python. It's the simplest possible form. There's no insertion (though that's trivial to add), it just takes a list of 'things' with start and stop attributes and creates a tree with a .find() method.
The part of the Wikipedia entry that baffled me was about storing 2 copies of each node's intervals--one sorted by start and the other by stop. I didn't do that, as I think in most cases it won't improve lookup time. I figure if you have 1 million elements and a tree of depth 16, then you have on average 15 intervals per node (actually fewer since there are the non-leaf nodes). So I just brute-force the intervals at each of those nodes and move to the next. I think that increases the worst-case time, but makes no difference in actual search time--with the benefit of halving storage.

EDIT: the version in my repo now keeps each node's intervals sorted by start, so a search can skip the brute-force scan at a node when search.stop < node.intervals[0].start. This did improve performance.
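The shortcut described in that EDIT might look roughly like the sketch below. It is not the repo version; `Interval` and `find_in_node` are hypothetical names made up for illustration, and the only assumption is that a node's intervals are sorted by start.

```python
from collections import namedtuple

# Hypothetical stand-in for the 'things' with start and stop attributes.
Interval = namedtuple("Interval", ["start", "stop"])


def find_in_node(sorted_intervals, start, stop):
    """Return the intervals in one node that overlap [start, stop].

    Assumes sorted_intervals is sorted by .start.
    """
    # If the query ends before the earliest start, nothing here can overlap,
    # so the brute-force scan is skipped entirely.
    if not sorted_intervals or stop < sorted_intervals[0].start:
        return []
    # Otherwise fall back to the brute-force scan of this node.
    return [i for i in sorted_intervals if i.stop >= start and i.start <= stop]


node_intervals = sorted([Interval(50, 60), Interval(40, 90), Interval(45, 47)],
                        key=lambda i: i.start)
print(find_in_node(node_intervals, 10, 20))  # query lies left of every start
print(find_in_node(node_intervals, 46, 55))  # query overlaps all three
```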

The tree class takes a list of intervals and calculates a center point. From there it partitions them into left, overlapping, and right in terms of their relation to the center point. Overlapping are assigned to the current node, and left and right are recursively partitioned in that fashion until a node has fewer than `minbucket` intervals, or the specified `depth` has been reached AND there are fewer intervals than `maxbucket`. So a tree can have a greater `depth` than requested if it would otherwise have more than `maxbucket` intervals in a single node. The Wikipedia version doesn't have maxbucket or minbucket...

EDIT: the maxbucket actually only works on leaf-nodes, and has no effect otherwise.
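The three-way split around the center point can be sketched on its own, outside the class. `Interval` and `partition` here are hypothetical helpers used only for illustration; the comparisons are the ones described above.

```python
from collections import namedtuple

Interval = namedtuple("Interval", ["start", "stop"])  # hypothetical stand-in


def partition(intervals, center):
    """Split intervals into (lefts, overlapping, rights) around center."""
    lefts, overlapping, rights = [], [], []
    for interval in intervals:
        if interval.stop < center:
            lefts.append(interval)        # entirely left of the center
        elif interval.start > center:
            rights.append(interval)       # entirely right of the center
        else:
            overlapping.append(interval)  # spans (or touches) the center
    return lefts, overlapping, rights


intervals = [Interval(0, 10), Interval(40, 60), Interval(90, 100), Interval(50, 50)]
left = min(i.start for i in intervals)
right = max(i.stop for i in intervals)
center = (left + right) / 2.0  # 50.0 for this data
lefts, overlapping, rights = partition(intervals, center)
print(lefts, overlapping, rights)
```

The overlapping list is what stays at the current node; the other two lists are what the recursion would be applied to.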

I'm sure that's painfully obvious for anyone who's ever taken a CS course, but it was foggy at best for me until I implemented it. Below is the entire implementation:

class IntervalTree(object):
    __slots__ = ('intervals', 'left', 'right', 'center')

    def __init__(self, intervals, depth=16, minbucket=96, _extent=None, maxbucket=4096):

        depth -= 1
        # leaf: stop splitting and keep everything here.
        if (depth == 0 or len(intervals) < minbucket) and len(intervals) < maxbucket:
            self.intervals = intervals
            self.left = self.right = None
            return

        # extent of this node; the root computes it from the data.
        left, right = _extent or \
               (min(i.start for i in intervals), max(i.stop for i in intervals))
        center = (left + right) / 2.0

        self.intervals = []
        lefts, rights = [], []

        for interval in intervals:
            if interval.stop < center:
                lefts.append(interval)
            elif interval.start > center:
                rights.append(interval)
            else: # overlapping.
                self.intervals.append(interval)

        # recurse on each side; an empty side stays None.
        self.left = lefts and IntervalTree(lefts, depth, minbucket, (left, center), maxbucket) or None
        self.right = rights and IntervalTree(rights, depth, minbucket, (center, right), maxbucket) or None
        self.center = center

    def find(self, start, stop):
        """find all elements between (or overlapping) start and stop"""
        overlapping = [i for i in self.intervals if i.stop >= start
                                                   and i.start <= stop]

        if self.left and start <= self.center:
            overlapping += self.left.find(start, stop)

        if self.right and stop >= self.center:
            overlapping += self.right.find(start, stop)

        return overlapping

Only 45 lines of code. I had added a couple of extra attributes so that searching could do fewer checks, but it only improved performance by ~15% and I liked the simplicity. One way to improve the search speed, and the distribution on skewed data, would be to sort the intervals at the top node, so they'd then be sorted for all other nodes. Then instead of using center = (left + right)/2, it could use the center point of the center interval at each node. That would also allow short-circuiting the brute-force search at the top of the find method with something like:

if not (start > self.intervals[-1].stop or stop < self.intervals[0].start):
    overlapping = [ .. list comprehension ]

But all told, that adds 5 or so lines of code. Oh, and depending on how it's used, it's between 15 and 25 times faster than brute-force search.
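To make the skewed-data point concrete, here is a small sketch comparing the two ways of choosing a center. None of these helpers exist in the code above; `extent_center`, `median_center` and `split_counts` are names invented for this example, and the data is synthetic.

```python
from collections import namedtuple

Interval = namedtuple("Interval", ["start", "stop"])  # hypothetical stand-in


def extent_center(intervals):
    """Midpoint of the overall extent, as in center = (left + right) / 2."""
    left = min(i.start for i in intervals)
    right = max(i.stop for i in intervals)
    return (left + right) / 2.0


def median_center(sorted_intervals):
    """Midpoint of the middle interval; expects intervals sorted by start."""
    mid = sorted_intervals[len(sorted_intervals) // 2]
    return (mid.start + mid.stop) / 2.0


def split_counts(intervals, center):
    """Count how many intervals fall left of, on, and right of center."""
    lefts = sum(1 for i in intervals if i.stop < center)
    rights = sum(1 for i in intervals if i.start > center)
    return lefts, len(intervals) - lefts - rights, rights


# Skewed data: four short intervals near 0 and one far outlier.
intervals = sorted([Interval(0, 1), Interval(2, 3), Interval(4, 5),
                    Interval(6, 7), Interval(1000, 1001)])
print(split_counts(intervals, extent_center(intervals)))  # (4, 0, 1)
print(split_counts(intervals, median_center(intervals)))  # (2, 1, 2)
```

With the extent midpoint, almost everything lands on one side; with the middle interval's midpoint, the two sides are balanced.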

EDIT: I added the above check, but it can only do the 2nd comparison, "stop < self.intervals[0].start", as the first is invalid given a very long interval. Regarding speed, the smaller the search window, the better the performance improvement. The code is now > 20 times as fast as brute force for a very large (speaking in terms of looking for genomic features) swath of 100K. With a search space of 50K, it's 50+ times as fast as linear search.
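A tiny counterexample shows why the first comparison cannot be trusted when the intervals are sorted by start only: the last interval by start need not have the largest stop. The data and names below are made up for illustration.

```python
from collections import namedtuple

Interval = namedtuple("Interval", ["start", "stop"])  # hypothetical stand-in

# Sorted by start: one very long interval followed by a short one.
intervals = sorted([Interval(0, 1000), Interval(5, 10)], key=lambda i: i.start)
start, stop = 500, 600

# First comparison: claims the query lies to the right of everything,
# because the last interval by start ends at 10.
skipped_by_first_check = start > intervals[-1].stop

# Brute force says otherwise: the long interval does overlap the query.
actual = [i for i in intervals if i.stop >= start and i.start <= stop]

# Second comparison: safe, since intervals[0] really has the smallest start.
safe_to_skip = stop < intervals[0].start

print(skipped_by_first_check, actual, safe_to_skip)
```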

The full code (including a docstring with homer simpson quote) is in my google code repo. If I've made obvious mistakes or you have improvements, I'd be glad to know them.

12 comments:

Bao said...

cool ... i wonder if it is more difficult to allow 'add' and 'remove' method for a dynamic IntervalTree

brentp said...

to add, just descend the tree and if the new interval.stop < node.center, go left. if interval.start > node.center, go right. otherwise, append the interval to node.intervals.
heheh.
since there's no balancing, it's easy.
but, yeah, that's definitely not optimal.

the tree here in bx-python does all that properly. I use that for some stuff--and i have a cython version of that on github. Though my pure python tree above is pretty quick as well.
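A minimal sketch of that unbalanced insert, for anyone who wants to try it. This is a hypothetical `Node` class, not the IntervalTree above, and it makes one assumption the description leaves open: when the descent reaches a missing child, the new child is centered on the interval being added.

```python
from collections import namedtuple

Interval = namedtuple("Interval", ["start", "stop"])  # hypothetical stand-in


class Node(object):
    """Hypothetical minimal node; not the IntervalTree class from the post."""

    def __init__(self, center):
        self.center = center
        self.intervals = []
        self.left = self.right = None

    def add(self, interval):
        node = self
        while True:
            if interval.stop < node.center:
                if node.left is None:
                    # Assumption: center the new child on the interval itself,
                    # so the next loop iteration stores it there.
                    node.left = Node((interval.start + interval.stop) / 2.0)
                node = node.left
            elif interval.start > node.center:
                if node.right is None:
                    node.right = Node((interval.start + interval.stop) / 2.0)
                node = node.right
            else:
                node.intervals.append(interval)
                return

    def find(self, start, stop):
        found = [i for i in self.intervals if i.stop >= start and i.start <= stop]
        if self.left is not None and start <= self.center:
            found += self.left.find(start, stop)
        if self.right is not None and stop >= self.center:
            found += self.right.find(start, stop)
        return found


root = Node(50.0)
for iv in [Interval(40, 60), Interval(0, 10), Interval(90, 100), Interval(20, 30)]:
    root.add(iv)
print(sorted(root.find(25, 45)))
```

Since nothing is rebalanced, the shape of the tree depends entirely on insertion order, which is the "not optimal" part.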

Anonymous said...

hi,

i am using your code but cannot make it faster than linear search. i used the code here: http://groups.google.com/group/comp.lang.python/browse_thread/thread/9b257dc4e658ad03#
any advice on this would very much help me. thank you.
-per

brentp said...

per, i replied to your thread. simply put, this (and most) trees will perform poorly as the length of the query interval increases.
check out bx-python's quicksect,
or my cython version here:
https://github.com/brentp/quicksect/tree
if you need a more robust implementation.

Anonymous said...

Just to make it a little bit faster:


"overlapping = [i for i in self.intervals if i.stop >= start and i.start <= stop]"

could be

"
overlapping = []
for i in self.intervals:
....if i.start<=stop:
........if i.stop >= start:
............overlapping.append(i)
........else:
............pass
....else:
........break"

brentp said...

@anon:
that makes sense, but are you sure that would make it faster? i suspect in most cases it is faster to do the list comprehension because you dont have the cost of an append each time.
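One way to settle it is to time both on the same data. The harness below is only a sketch with synthetic intervals; the function names are made up, and the early `break` is only valid when the node's intervals are sorted by start. No timings are claimed here, since they depend on the data and the Python version.

```python
import timeit
from collections import namedtuple

Interval = namedtuple("Interval", ["start", "stop"])  # hypothetical stand-in

# Synthetic node contents, sorted by start (required for the early break).
intervals = [Interval(s, s + 25) for s in range(0, 1000, 10)]


def with_comprehension(start, stop):
    return [i for i in intervals if i.stop >= start and i.start <= stop]


def with_loop_and_break(start, stop):
    overlapping = []
    for i in intervals:
        if i.start <= stop:
            if i.stop >= start:
                overlapping.append(i)
        else:
            break  # sorted by start: no later interval can overlap
    return overlapping


query = (200, 300)
same = with_comprehension(*query) == with_loop_and_break(*query)
for fn in (with_comprehension, with_loop_and_break):
    seconds = timeit.timeit(lambda: fn(*query), number=2000)
    print(fn.__name__, round(seconds, 4))
```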

pzs said...

Many thanks for this. I'm filing affymetrix tiles (20 bases wide) into bins across the human genome. I have 40 million points to file into about 130,000 bins and the tree improved the speed from more than 10 hours to less than 90 minutes.

I have made a few tweaks. If you inadvertently feed it an interval where the upper bound is less than the lower bound, it gives an enormous runtime exception, so I added an assert to check for that. I also added getstate and setstate functions so that I can pickle them:

def __getstate__(self):
    return { 'intervals' : self.intervals,
             'left' : self.left,
             'right' : self.right,
             'center' : self.center }

def __setstate__(self, state):
    for key,value in state.items():
        setattr(self, key, value)
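A quick round trip shows the two methods working with pickle on a class that uses `__slots__`. The `Node` class below is a hypothetical stand-in with the same four slots; it assigns every slot in `__init__`, because `__getstate__` reads all four and an unassigned slot raises AttributeError.

```python
import pickle


class Node(object):
    """Hypothetical slotted node standing in for the tree class."""
    __slots__ = ('intervals', 'left', 'right', 'center')

    def __init__(self, intervals, center, left=None, right=None):
        self.intervals = intervals
        self.center = center
        self.left = left
        self.right = right

    def __getstate__(self):
        return {'intervals': self.intervals,
                'left': self.left,
                'right': self.right,
                'center': self.center}

    def __setstate__(self, state):
        for key, value in state.items():
            setattr(self, key, value)


tree = Node([(40, 60)], 50.0, left=Node([(0, 10)], 5.0))
restored = pickle.loads(pickle.dumps(tree, protocol=2))
print(restored.center, restored.left.intervals, restored.right)
```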

brentp said...

pzs, that's very cool! glad to hear it's useful. i'm surprised you dont run out of memory with that many intervals!

that's an awesome example of get/set-state, so i added it to the code:
http://code.google.com/p/bpbio/source/detail?spec=svn103&r=103

i suppose it's unneeded if this works for you, but also check out the version in bx-python -- http://bitbucket.org/james_taylor/bx-python/src/d9c88c9359a0/lib/bx/intervals/intersection.pyx


it's in cython, so it's _very_ fast. you might get your times down below 10 minutes.

pzs said...

It runs in around 4.5GB memory.
I'm currently running a bootstrap on this data - you randomise the point positions and file them 100 times and then t-test against the value from your data. Yes, that means 100 iterations of the 90 minute filing!

Thanks for the tip about bx-trees.
If bx-trees get it down to 10 minutes that will make a huge difference to us.

Incidentally, despite what I said about the pickling I haven't ended up using that. My intervals are stored in csv files and the Python csv module is so fast that it's pretty quick to just build them fresh each time.

pzs said...

It comes down to around 6 1/2 minutes :)))

Daniel Standage said...

From the description in your blog post, it appears that you want to check len(intervals) < maxbucket (on line 44 of the most recent revision r106). If the number of intervals is greater than or equal to maxbucket, you want to keep building subtrees right?

brentp said...

@Daniel, yes, fixed now. oddly, someone else just pointed this out yesterday.