提交 10701e21 authored 作者: Saizheng Zhang's avatar Saizheng Zhang

make mrg_uniform_base to be broadcastable

上级 2656747a
...@@ -318,9 +318,9 @@ class mrg_uniform_base(Op): ...@@ -318,9 +318,9 @@ class mrg_uniform_base(Op):
# #
# call through MRG_RandomStreams instead. # call through MRG_RandomStreams instead.
broad=[] broad=[]
for i in self.output_type.ndim: for i in range(self.output_type.ndim):
broad.append(T.extract_constant(size[i]) == 1) broad.append(tensor.extract_constant(size[i]) == 1)
output_type = self.output_type.clone(broadcastab le=broad)() output_type = self.output_type.clone(broadcastable=broad)()
return Apply(self, return Apply(self,
[rstate, size], [rstate, size],
[rstate.type(), output_type]) [rstate.type(), output_type])
......
...@@ -627,6 +627,22 @@ def test_uniform(): ...@@ -627,6 +627,22 @@ def test_uniform():
allow_01=True, inputs=input) allow_01=True, inputs=input)
def test_uniform_broadcastable():
x = tensor.matrix()
size1 = (10, 1)
size2 = (x.shape[0], 1)
R = MRG_RandomStreams(234, use_cuda=False)
# check when all dimensions are constant
uu = R.uniform(size=size1)
assert uu.broadcastable == (False, True)
# check when some dimensions are theano variables
uu = R.uniform(size=size2)
assert uu.broadcastable == (False, True)
@attr('slow') @attr('slow')
def test_binomial(): def test_binomial():
# TODO: test size=None, ndim=X # TODO: test size=None, ndim=X
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论