提交 979133cf authored 作者: Frederic Bastien's avatar Frederic Bastien

Elemwise.infer_shape work when their is a scalar in input

上级 c4596efb
...@@ -635,6 +635,8 @@ class Elemwise(Op): ...@@ -635,6 +635,8 @@ class Elemwise(Op):
b_dim = 1 b_dim = 1
else: # there must be some input that is not broadcastable else: # there must be some input that is not broadcastable
for ishp, i in zip(i_shapes,node.inputs): for ishp, i in zip(i_shapes,node.inputs):
if isinstance(i.type,theano.scalar.Scalar):
continue #we skip scalar
if not i.type.broadcastable[dim]: if not i.type.broadcastable[dim]:
b_dim = ishp[dim] b_dim = ishp[dim]
assert b_dim, 'AA' assert b_dim, 'AA'
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论