提交 c2909c93 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Replace Shape_i with any static shape value and not just 1

上级 1a530729
...@@ -1020,8 +1020,8 @@ def local_Shape_of_SpecifyShape(fgraph, node): ...@@ -1020,8 +1020,8 @@ def local_Shape_of_SpecifyShape(fgraph, node):
@register_useless @register_useless
@register_canonicalize @register_canonicalize
@node_rewriter([Shape_i]) @node_rewriter([Shape_i])
def local_Shape_i_of_broadcastable(fgraph, node): def local_Shape_i_ground(fgraph, node):
"""Replace ``shape_i(x, i)`` with ``1`` when ``x.broadcastable[i]`` is ``True``.""" """Replace ``shape_i(x, i)`` with ``s`` when ``x.type.shape[i] == s``."""
if not isinstance(node.op, Shape_i): if not isinstance(node.op, Shape_i):
return False return False
...@@ -1031,8 +1031,9 @@ def local_Shape_i_of_broadcastable(fgraph, node): ...@@ -1031,8 +1031,9 @@ def local_Shape_i_of_broadcastable(fgraph, node):
if not isinstance(shape_arg.type, TensorType): if not isinstance(shape_arg.type, TensorType):
return False return False
if shape_arg.broadcastable[node.op.i]: s_val = shape_arg.type.shape[node.op.i]
return [as_tensor_variable(1, dtype=np.int64)] if s_val is not None:
return [as_tensor_variable(s_val, dtype=np.int64)]
@register_specialize @register_specialize
......
...@@ -493,15 +493,15 @@ def test_local_Shape_of_SpecifyShape_partial(s1): ...@@ -493,15 +493,15 @@ def test_local_Shape_of_SpecifyShape_partial(s1):
assert not any(isinstance(apply.op, SpecifyShape) for apply in fgraph.apply_nodes) assert not any(isinstance(apply.op, SpecifyShape) for apply in fgraph.apply_nodes)
def test_local_Shape_i_of_broadcastable(): def test_local_Shape_i_ground():
x = tensor(np.float64, shape=(None, 1)) x = tensor(np.float64, shape=(None, 2))
s = Shape_i(1)(x) s = Shape_i(1)(x)
fgraph = FunctionGraph(outputs=[s], clone=False) fgraph = FunctionGraph(outputs=[s], clone=False)
_ = rewrite_graph(fgraph, clone=False) _ = rewrite_graph(fgraph, clone=False)
assert x not in fgraph.variables assert x not in fgraph.variables
assert fgraph.outputs[0].data == 1 assert fgraph.outputs[0].data == 2
# A test for a non-`TensorType` # A test for a non-`TensorType`
class MyType(Type): class MyType(Type):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论