提交 499334e6 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix a case in local_subtensor_lift where the optimization was failing.

A broadcasting pattern of '' was used, instead of [].
上级 6dacc650
...@@ -1105,7 +1105,7 @@ def local_subtensor_lift(node): ...@@ -1105,7 +1105,7 @@ def local_subtensor_lift(node):
if node.outputs[0].ndim == i.ndim: if node.outputs[0].ndim == i.ndim:
new_inputs.append(i) new_inputs.append(i)
else: else:
new_inputs.append(i.dimshuffle('x'*node.outputs[0].ndim)) new_inputs.append(i.dimshuffle(['x']*node.outputs[0].ndim))
return [u.owner.op(*new_inputs)] return [u.owner.op(*new_inputs)]
@register_canonicalize @register_canonicalize
......
...@@ -1244,6 +1244,22 @@ class test_local_subtensor_lift(unittest.TestCase): ...@@ -1244,6 +1244,22 @@ class test_local_subtensor_lift(unittest.TestCase):
assert len(prog)==4 assert len(prog)==4
f([[0,1],[2,3]], [4,5]) # let debugmode test something f([[0,1],[2,3]], [4,5]) # let debugmode test something
def test6(self):
# basic test that the optimization works with a scalar as input,
# and a scalar as output (no broadcasting of the scalar needed).
# The optimization used to fail and display an ERROR message.
x = TT.vector('x')
y = TT.scalar('y')
f = function([x,y], TT.exp(x+y)[0], mode=mode_opt)
prog=f.maker.env.toposort()
assert isinstance(prog[0].op, TT.Subtensor)
# Composite{add,exp}
assert isinstance(prog[1].op.scalar_op, theano.scalar.Composite)
assert len(prog)==2
f([1,2,3], 4) # let debugmode test something
class test_local_subtensor_merge(unittest.TestCase): class test_local_subtensor_merge(unittest.TestCase):
def test_const(self): def test_const(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论