Tuesday, June 2, 2009

Python multidimensional sparse array of objects

Posted by Danny Tarlow
A research project I'm working on requires that I store a ton of objects. Each object can be identified by a 4-dimensional index, but in any given run I only use a small (and different) subset of the indexes. I was surprised not to find anything that fits my needs in numpy or scipy, or by Googling (maybe I didn't look hard enough).

Anyhow, I grabbed the blist module and built a simple multidimensional version on top of it. It isn't fancy, but it does the trick for me.
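
The "multidimensional" part is just index arithmetic. Each N-d index is flattened into one offset in a flat list, with the first coordinate varying fastest, which is what `base_index` in the code below computes. For dims (d0, d1, d2, d3), the offset is i0 + i1*d0 + i2*d0*d1 + i3*d0*d1*d2. Here is a minimal stdlib-only sketch of that mapping and its inverse. `flatten` and `unflatten` are hypothetical helper names for illustration. They are not part of blist or of the class below.

```python
def flatten(index, dims):
    """Map an N-d index to a flat offset; the first coordinate varies fastest."""
    if len(index) != len(dims):
        raise IndexError("expected %d coordinates, got %d" % (len(dims), len(index)))
    offset = 0
    stride = 1  # product of the sizes of all earlier dimensions
    for i, d in zip(index, dims):
        # Check every coordinate against its own dimension, not just the total.
        if not 0 <= i < d:
            raise IndexError("coordinate %d out of range for size %d" % (i, d))
        offset += i * stride
        stride *= d
    return offset


def unflatten(offset, dims):
    """Inverse of flatten; assumes 0 <= offset < product(dims)."""
    index = []
    for d in dims:
        offset, i = divmod(offset, d)  # peel off the fastest-varying coordinate
        index.append(i)
    return tuple(index)


dims = (100, 100, 100, 100)
print(flatten((23, 33, 10, 5), dims))                   # -> 5103323
print(unflatten(flatten((23, 33, 10, 5), dims), dims))  # -> (23, 33, 10, 5)
```

With this ordering, `(100, 0, 0, 0)` lands on the same offset as `(0, 1, 0, 0)`. That is why the sketch checks each coordinate against its own dimension rather than only comparing the final offset to the total size.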

Let me know if you know of any more fleshed-out packages that can do this.
from blist import blist  # third-party package: pip install blist

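class SparseArray(object):
    def __init__(self, dims):
        self.dims = dims

        # Total number of cells = product of all dimension sizes.
        self.size = 1
        for d in self.dims:
            self.size *= d

        # blist shares internal nodes when a list is multiplied, so a huge
        # list of zeros stays cheap until individual cells are set.
        self.base_list = blist([0]) * self.size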


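    def base_index(self, index):
        """ Convert N-d index to index in base list (first coordinate varies fastest). """
        if len(index) != len(self.dims):
            raise IndexError("Expected %d coordinates, got %d" % (len(self.dims),
                                                                  len(index)))

        result = 0
        multiplier = 1
        for i, d in zip(index, self.dims):
            # Check each coordinate on its own; checking only the final offset
            # lets e.g. (100, 0, 0, 0) silently alias (0, 1, 0, 0).
            if not 0 <= i < d:
                raise IndexError("Index out of range %s (dims: %s)" % (index,
                                                                      self.dims))
            result += i * multiplier
            multiplier *= d

        return result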


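    def get(self, index):
        """ Return the value at an N-d index (0 if it was never set). """
        return self.base_list[self.base_index(index)]


    def set(self, index, val):
        """ Store any object at an N-d index. """
        self.base_list[self.base_index(index)] = val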


if __name__ == "__main__":
    s = SparseArray((100, 100, 100, 100))

    s.set((0, 0, 0, 0), ['first'])
    s.set((23, 33, 10, 5), 10)
    s.set((99, 99, 99, 99), ['last'])

    print(s.get((0, 0, 0, 0)))       # ['first']
    print(s.get((23, 33, 10, 5)))    # 10
    print(s.get((99, 99, 99, 98)))   # 0 (never set, so the default fill)
    print(s.get((99, 99, 99, 99)))   # ['last']

    try:
        print(s.get((99, 99, 99, 100)))
    except IndexError:
        print("Properly caught index error")
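
If blist isn't available (it is a third-party package), a plain dict keyed by index tuples gives similar sparse behavior with only the standard library. Only cells that have been set take memory, and reads of unset cells fall back to a default. This is a swapped-in technique, not the approach used in this post. `DictSparseArray` is a hypothetical name, and the default of 0 mirrors the `blist([0])` fill above.

```python
class DictSparseArray(object):
    """Sparse N-d array backed by a dict keyed by index tuples (illustrative)."""

    def __init__(self, dims, default=0):
        self.dims = tuple(dims)
        self.default = default  # returned for cells that were never set
        self.data = {}          # only cells that have been set are stored

    def _check(self, index):
        """Validate an index against every dimension and return it as a tuple."""
        index = tuple(index)
        if len(index) != len(self.dims):
            raise IndexError("expected %d coordinates, got %d"
                             % (len(self.dims), len(index)))
        for i, d in zip(index, self.dims):
            if not 0 <= i < d:
                raise IndexError("coordinate %d out of range for size %d" % (i, d))
        return index

    def get(self, index):
        return self.data.get(self._check(index), self.default)

    def set(self, index, val):
        self.data[self._check(index)] = val


s = DictSparseArray((100, 100, 100, 100))
s.set((0, 0, 0, 0), ['first'])
s.set((23, 33, 10, 5), 10)
print(s.get((23, 33, 10, 5)))   # -> 10
print(s.get((99, 99, 99, 98)))  # -> 0 (never set)
print(len(s.data))              # -> 2
```

The trade-off is that each stored cell also keeps its key tuple. In addition, iterating in index order requires sorting the keys, because a dict iterates in insertion order rather than index order.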
