Tuesday, June 2, 2009

Python multidimensional sparse array of objects

Posted by Danny Tarlow
A research project I'm working on requires that I store a ton of objects. Each object can be identified by a 4-dimensional index, but any given execution touches only a small (and different) subset of the indexes. I was surprised not to find anything that fits my needs in numpy or scipy, or by Googling (maybe I didn't look hard enough).

Anyhow, I grabbed the blist module and built a simple multidimensional version on top of it. It isn't fancy, but it does the trick for me.

Let me know if you know of any more fleshed-out packages that can do this.

from blist import blist

class SparseArray():
    def __init__(self, dims):
        self.dims = dims
        self.size = 1
        for d in self.dims:
            self.size *= d
        self.base_list = self.size * blist([0])

    def base_index(self, index):
        """ Convert N-d index to index in base list. """
        result = 0
        multiplier = 1
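        for i in range(len(index)):
            # Check each axis on its own: a flat-index check alone would let
            # e.g. (100, 0, 0, 0) silently alias (0, 1, 0, 0).
            if not 0 <= index[i] < self.dims[i]:
                raise IndexError("Index out of range %s (dims: %s)" % (index, self.dims))
            result += index[i] * multiplier
            multiplier *= self.dims[i]
        return result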

    def get(self, index):
        return self.base_list[self.base_index(index)]

    def set(self, index, val):
        self.base_list[self.base_index(index)] = val

if __name__ == "__main__":
    s = SparseArray((100, 100, 100, 100))
    s.set((0, 0, 0, 0), ['first'])
    s.set((23, 33, 10, 5), 10)
    s.set((99, 99, 99, 99), ['last'])
    print(s.get((0, 0, 0, 0)))
    print(s.get((23, 33, 10, 5)))
    print(s.get((99, 99, 99, 98)))
    print(s.get((99, 99, 99, 99)))
    # The last coordinate is out of range, so this lookup should raise.
    try:
        print(s.get((99, 99, 99, 100)))
    except IndexError:
        print("Properly caught index error")
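
For comparison, here is a rough sketch of the same get/set interface built on a plain dict keyed by index tuples, which needs no third-party module. This is not what the code above does: the class name DictSparseArray, the default argument, and the _check helper are all made up for illustration, and I am assuming that unset cells should read back as 0, as they do in the blist version.

```python
class DictSparseArray(object):
    """Hypothetical dict-backed variant: only cells that were set use memory."""

    def __init__(self, dims, default=0):
        self.dims = tuple(dims)
        self.default = default  # assumed value for cells never set
        self.cells = {}         # maps index tuple -> stored object

    def _check(self, index):
        """ Validate an N-d index and return it as a tuple. """
        index = tuple(index)
        if len(index) != len(self.dims):
            raise IndexError("Expected %d coordinates, got %d"
                             % (len(self.dims), len(index)))
        for i, d in zip(index, self.dims):
            if not 0 <= i < d:
                raise IndexError("Index out of range %s (dims: %s)"
                                 % (index, self.dims))
        return index

    def get(self, index):
        return self.cells.get(self._check(index), self.default)

    def set(self, index, val):
        self.cells[self._check(index)] = val


s = DictSparseArray((100, 100, 100, 100))
s.set((0, 0, 0, 0), ['first'])
s.set((23, 33, 10, 5), 10)
```

Memory here grows with the number of cells actually set rather than with the product of the dimensions. Whether it beats the blist approach in practice would depend on the access pattern, so treat it as a starting point rather than a recommendation.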
