提交 5a3a683b authored 作者: James Bergstra's avatar James Bergstra

added comments and tests for RandomFunction

上级 cbe7fef9
...@@ -23,6 +23,7 @@ class RandomStateType(gof.Type): ...@@ -23,6 +23,7 @@ class RandomStateType(gof.Type):
""" """
def __str__(self): def __str__(self):
return 'RandomStateType' return 'RandomStateType'
def filter(self, data, strict=False): def filter(self, data, strict=False):
if self.is_valid_value(data): if self.is_valid_value(data):
return data return data
...@@ -115,7 +116,6 @@ class RandomFunction(gof.Op): ...@@ -115,7 +116,6 @@ class RandomFunction(gof.Op):
the random draw. the random draw.
""" """
args = map(tensor.as_tensor, args)
if shape == () or shape == []: if shape == () or shape == []:
shape = tensor.lvector() shape = tensor.lvector()
else: else:
...@@ -127,20 +127,33 @@ class RandomFunction(gof.Op): ...@@ -127,20 +127,33 @@ class RandomFunction(gof.Op):
print >> sys.stderr, 'WARNING: RandomState instances should be in RandomStateType' print >> sys.stderr, 'WARNING: RandomState instances should be in RandomStateType'
if 0: if 0:
raise TypeError('r must be RandomStateType instance', r) raise TypeError('r must be RandomStateType instance', r)
# assert shape.type == tensor.lvector doesn't work because we want to ignore the # the following doesn't work because we want to ignore the broadcastable flags in
# broadcastable vector # shape.type
assert len(args) <= len(self.args) # assert shape.type == tensor.lvector
# convert args to Tensor instances
# and append enough None's to match the length of self.args
args = map(tensor.as_tensor, args)
if len(args) > len(self.args):
raise TypeError('Too many args for this kind of random generator')
args += (None,) * (len(self.args) - len(args)) args += (None,) * (len(self.args) - len(args))
assert len(args) == len(self.args)
# build the inputs to this Apply by overlaying args on self.args
inputs = [] inputs = []
for arg, default in zip(args, self.args): for arg, default in zip(args, self.args):
assert arg is None or default.type.dtype == arg.type.dtype assert arg is None or default.type.dtype == arg.type.dtype
input = default if arg is None else arg input = default if arg is None else arg
inputs.append(input) inputs.append(input)
return gof.Apply(self, return gof.Apply(self,
[r, shape] + inputs, [r, shape] + inputs,
[r.type(), self.outtype()]) [r.type(), self.outtype()])
def perform(self, node, inputs, (rout, out)): def perform(self, node, inputs, (rout, out)):
# Use self.fn to draw shape worth of random numbers.
# Numbers are drawn from r if self.inplace is True, and from a copy of r if
# self.inplace is False
r, shape, args = inputs[0], inputs[1], inputs[2:] r, shape, args = inputs[0], inputs[1], inputs[2:]
assert type(r) == numpy.random.RandomState assert type(r) == numpy.random.RandomState
r_orig = r r_orig = r
......
...@@ -16,7 +16,6 @@ class T_random_function(unittest.TestCase): ...@@ -16,7 +16,6 @@ class T_random_function(unittest.TestCase):
assert getattr(rf, 'destroy_map', {}) == {} assert getattr(rf, 'destroy_map', {}) == {}
rng_R = random_state_type() rng_R = random_state_type()
print rng_R
post_r, out = rf(rng_R, (4,)) post_r, out = rf(rng_R, (4,))
...@@ -37,8 +36,58 @@ class T_random_function(unittest.TestCase): ...@@ -37,8 +36,58 @@ class T_random_function(unittest.TestCase):
assert rf.inplace assert rf.inplace
assert getattr(rf, 'destroy_map', {}) != {} assert getattr(rf, 'destroy_map', {}) != {}
def test_args(self):
"""Test that arguments to RandomFunction are honored"""
rf2 = RandomFunction(numpy.random.RandomState.uniform, tensor.dvector, -2.0, 2.0)
rf4 = RandomFunction(numpy.random.RandomState.uniform, tensor.dvector, -4.0, 4.0,
inplace=True)
rng_R = random_state_type()
# use make_node to override some of the self.args
post_r2, out2 = rf2(rng_R, (4,))
post_r2_4, out2_4 = rf2(rng_R, (4,), -4.0)
post_r2_4_4, out2_4_4 = rf2(rng_R, (4,), -4.0, 4.0)
post_r4, out4 = rf4(rng_R, (4,))
f = compile.function(
[compile.In(rng_R, value=numpy.random.RandomState(55), update=post_r4, mutable=True)],
[out2, out4, out2_4, out2_4_4],
accept_inplace=True)
f2, f4, f2_4, f2_4_4 = f()
f2b, f4b, f2_4b, f2_4_4b = f()
assert numpy.allclose(f2*2, f4)
assert numpy.allclose(f2_4_4, f4)
assert not numpy.allclose(f4, f4b)
def test_inplace_optimization(self): def test_inplace_optimization(self):
print >> sys.stderr, "WARNING NOT IMPLEMENTED T_random_function.test_inplace_optimization" """Test that arguments to RandomFunction are honored"""
#inplace = False
rf2 = RandomFunction(numpy.random.RandomState.uniform, tensor.dvector, -2.0, 2.0)
rng_R = random_state_type()
# use make_node to override some of the self.args
post_r2, out2 = rf2(rng_R, (4,))
f = compile.function(
[compile.In(rng_R,
value=numpy.random.RandomState(55),
update=post_r2,
mutable=True)],
out2,
mode='FAST_RUN') #DEBUG_MODE can't pass the id-based test below
# test that the RandomState object stays the same from function call to function call,
# but that the values returned change from call to call.
id0 = id(f[rng_R])
val0 = f()
assert id0 == id(f[rng_R])
val1 = f()
assert id0 == id(f[rng_R])
assert not numpy.allclose(val0, val1)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论