提交 b4817c25 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix check for broadcastable pattern.

上级 aebaa276
......@@ -618,8 +618,8 @@ def local_upcast_elemwise_constant_inputs(node):
new_inputs.append(i)
else:
try:
cval_i = get_constant_value(i) # works only for scalars I think
if all((not b for b in i.broadcastable)):
cval_i = get_constant_value(i) # works only for scalars
if all(i.broadcastable):
new_inputs.append(T.cast(cval_i, output_dtype))
else:
if shape_i is None:
......@@ -638,9 +638,11 @@ def local_upcast_elemwise_constant_inputs(node):
if rval[0].type != node.outputs[0].type:
print >> sys.stderr, "NODE:", node
print >> sys.stderr, "NODE INPUT TYPES:", [i.type for i in node.inputs]
print >> sys.stderr, "NODE OUTPUT TYPES:", [o.type for o in node.outputs]
print >> sys.stderr, "RVAL:", rval
print >> sys.stderr, "NEW INPUT TYPES:", [i.type for i in new_inputs]
print >> sys.stderr, "RVAL INPUT TYPES:", [i.type for i in rval[0].owner.inputs]
print >> sys.stderr, "RVAL TYPES:", [o.type for o in rval]
assert rval[0].type == node.outputs[0].type, (node, rval[0])
return rval
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论