提交 4c8785f0 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix bug in patternbroadcast with empty pattern

make_node used to crash, now it works. A test detects the problem.
上级 7db4620f
...@@ -4079,7 +4079,7 @@ class Rebroadcast(Op): ...@@ -4079,7 +4079,7 @@ class Rebroadcast(Op):
broadcast_pattern[k] = str(int(v)) broadcast_pattern[k] = str(int(v))
return '%s{%s}' % (self.__class__.__name__, ','.join(broadcast_pattern)) return '%s{%s}' % (self.__class__.__name__, ','.join(broadcast_pattern))
def make_node(self, x): def make_node(self, x):
if x.ndim <= numpy.max(self.axis.keys()): if self.axis.keys() and (x.ndim <= numpy.max(self.axis.keys())):
raise ValueError('Trying to rebroadcast nonexistant dimension') raise ValueError('Trying to rebroadcast nonexistant dimension')
t = x.type.__class__(dtype = x.type.dtype, t = x.type.__class__(dtype = x.type.dtype,
broadcastable = [self.axis.get(i, b) broadcastable = [self.axis.get(i, b)
......
...@@ -34,7 +34,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as, ...@@ -34,7 +34,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as,
get_constant_value, ivector, reshape, scalar_from_tensor, scal, get_constant_value, ivector, reshape, scalar_from_tensor, scal,
iscalars, arange, dscalars, fvector, imatrix, numeric_grad, iscalars, arange, dscalars, fvector, imatrix, numeric_grad,
opt, ComplexError, TensorDot, lvector, true_div, max, min, Split, roll, opt, ComplexError, TensorDot, lvector, true_div, max, min, Split, roll,
tile) tile, patternbroadcast)
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
...@@ -5267,6 +5267,15 @@ class test_broadcast(unittest.TestCase): ...@@ -5267,6 +5267,15 @@ class test_broadcast(unittest.TestCase):
assert addbroadcast(unbroadcast(x,1),0).owner.inputs[0] is x assert addbroadcast(unbroadcast(x,1),0).owner.inputs[0] is x
assert addbroadcast(unbroadcast(x,0),0) is x assert addbroadcast(unbroadcast(x,0),0) is x
def test_patternbroadcast(self):
# Test that patternbroadcast with an empty broadcasting pattern works
x = scalar('x')
m = tensor.matrix('m')
s = patternbroadcast(m, x.broadcastable)
assert s is m
x2 = patternbroadcast(x, x.broadcastable)
assert x2 is x
def test_infer_shape(self): def test_infer_shape(self):
x = matrix() x = matrix()
y = addbroadcast(x, 0) y = addbroadcast(x, 0)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论