提交 60eb8ce2 authored 作者: James Bergstra's avatar James Bergstra

added __eq__ and __hash__ to NumpyGenerator

上级 8085810d
......@@ -68,6 +68,19 @@ class T_Random(unittest.TestCase):
self.failUnless(str(v0[0,0]).startswith('0.013259'))
self.failUnless(str(v0[1,2]).startswith('0.753368'))
def test5(self):
"""Test that two NumpyGenerators with the same dist compare equal"""
rng0 = RandomState(123456)
rng1 = RandomState(123456)
d0 = rng0.gen(('beta',{'a':0.5,'b':0.65}), (2,3,4))
d1 = rng1.gen(('beta',{'a':0.5,'b':0.65}), (2,3,4))
self.failUnless(d0.owner.op == d1.owner.op)
self.failUnless(hash(d0.owner.op) == hash(d1.owner.op))
if __name__ == '__main__':
unittest.main()
import os, sys
import scipy.weave as weave
import gof.utils
"""
File: omega/blas.py
......@@ -794,16 +795,7 @@ def blas_proto():
}
"""
def _constant(f):
"""Return a function that always returns its first call value
"""
def rval(*args, **kwargs):
if not hasattr(f, 'rval'):
f.rval = f(*args, **kwargs)
return f.rval
return rval
@_constant
@gof.utils.memoize
def ldflags():
"""Return a list of libraries against which an Op's object file should be
linked to benefit from a BLAS implementation.
......
......@@ -36,6 +36,21 @@ class scratchpad:
print "scratch" + str(self.__dict__)
def memoize(f):
"""Cache the return value for each tuple of arguments
(which must be hashable) """
cache = {}
def rval(*args, **kwargs):
kwtup = tuple(kwargs.items())
key = (args, kwtup)
if key not in cache:
val = f(*args, **kwargs)
cache[key] = val
else:
val = cache[key]
return val
return rval
def deprecated(filename, msg=''):
......
......@@ -3,16 +3,23 @@ import tensor
import numpy
import functools
def fn_from_dist(dist):
# the optional argument implements a closure
# the cache is used so that we we can be sure that
# id(self.fn) in NumpyGenerator identifies
# the computation performed.
def fn_from_dist(dist, cache={}):
if callable(dist):
return dist
if isinstance(dist, str):
return getattr(numpy.random.RandomState, dist)
name, kwargs = dist
fn = getattr(numpy.random.RandomState, name)
fn = functools.partial(fn, **kwargs)
return fn
key = (name, tuple(kwargs.items()))
if key not in cache:
fn = getattr(numpy.random.RandomState, name)
fn = functools.partial(fn, **kwargs)
cache[key] = fn
return cache[key]
class RandomState(object):
def __init__(self, seed):
......@@ -45,14 +52,22 @@ class NumpyGenerator(gof.op.Op):
self.ndim = ndim
self.fn = fn
def __eq__(self, other):
return (type(self) is type(other))\
and self.__class__ is NumpyGenerator \
and self.seed == other.seed \
and self.ndim == other.ndim \
and self.fn == other.fn
def __hash__(self):
return self.seed ^ self.ndim ^ id(self.fn)
def make_node(self, _shape):
#TODO: check for constant shape, and guess the broadcastable bits
shape = tensor.convert_to_int64(_shape)
if shape.type.ndim != 1:
raise TypeError('shape argument was not converted to 1-d tensor', _shape)
inputs = [gof.Value(gof.type.generic, numpy.random.RandomState(self.seed)), shape]
outputs = [gof.Result(tensor.Tensor(dtype='float64', broadcastable =
[False]*self.ndim))]
outputs = [tensor.Tensor(dtype='float64', broadcastable = [False]*self.ndim).make_result()]
return gof.Apply(op = self, inputs = inputs, outputs = outputs)
def grad(self, inputs, grad_outputs):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论