提交 735cb923 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #5269 from Saizheng/master

Make mrg_uniform_base to be broadcastable
...@@ -317,9 +317,13 @@ class mrg_uniform_base(Op): ...@@ -317,9 +317,13 @@ class mrg_uniform_base(Op):
# this op should not be called directly. # this op should not be called directly.
# #
# call through MRG_RandomStreams instead. # call through MRG_RandomStreams instead.
broad = []
for i in range(self.output_type.ndim):
broad.append(tensor.extract_constant(size[i]) == 1)
output_type = self.output_type.clone(broadcastable=broad)()
return Apply(self, return Apply(self,
[rstate, size], [rstate, size],
[rstate.type(), self.output_type()]) [rstate.type(), output_type])
def grad(self, inputs, ograd): def grad(self, inputs, ograd):
return [gradient.grad_undefined(self, k, inp, return [gradient.grad_undefined(self, k, inp,
......
...@@ -627,6 +627,22 @@ def test_uniform(): ...@@ -627,6 +627,22 @@ def test_uniform():
allow_01=True, inputs=input) allow_01=True, inputs=input)
def test_uniform_broadcastable():
x = tensor.matrix()
size1 = (10, 1)
size2 = (x.shape[0], 1)
R = MRG_RandomStreams(234, use_cuda=False)
# check when all dimensions are constant
uu = R.uniform(size=size1)
assert uu.broadcastable == (False, True)
# check when some dimensions are theano variables
uu = R.uniform(size=size2)
assert uu.broadcastable == (False, True)
@attr('slow') @attr('slow')
def test_binomial(): def test_binomial():
# TODO: test size=None, ndim=X # TODO: test size=None, ndim=X
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论