提交 5bda83ef authored 作者: Frederic Bastien's avatar Frederic Bastien 提交者: Frederic

get_scalar_constant_value() now know about Rebroadcast(x,...).shape[idx].

上级 53ccd442
...@@ -683,6 +683,14 @@ def get_scalar_constant_value(orig_v, elemwise=True): ...@@ -683,6 +683,14 @@ def get_scalar_constant_value(orig_v, elemwise=True):
grandparent = leftmost_parent.owner.inputs[0] grandparent = leftmost_parent.owner.inputs[0]
gp_broadcastable = grandparent.type.broadcastable gp_broadcastable = grandparent.type.broadcastable
ndim = grandparent.type.ndim ndim = grandparent.type.ndim
if grandparent.owner and isinstance(grandparent.owner.op,
Rebroadcast):
l = []
for idx, (b1, b2) in enumerate(
zip(grandparent.owner.inputs[0].broadcastable,
gp_broadcastable)):
l.append(b1 or b2)
gp_broadcastable = tuple(l)
assert ndim == len(gp_broadcastable) assert ndim == len(gp_broadcastable)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论