提交 a0be97e8 authored 作者: jessegrabowski's avatar jessegrabowski 提交者: Ricardo Vieira

Canonicalize `SplitDims` and `JoinDims` to `Reshape`

上级 c3347b98
......@@ -10,6 +10,7 @@ import pytensor.tensor.rewriting.linalg
import pytensor.tensor.rewriting.math
import pytensor.tensor.rewriting.numba
import pytensor.tensor.rewriting.ofg
import pytensor.tensor.rewriting.reshape
import pytensor.tensor.rewriting.shape
import pytensor.tensor.rewriting.special
import pytensor.tensor.rewriting.subtensor
......
from pytensor.graph import node_rewriter
from pytensor.graph.rewriting.basic import copy_stack_trace
from pytensor.tensor.reshape import JoinDims, SplitDims
from pytensor.tensor.rewriting.basic import register_canonicalize
@register_canonicalize
@node_rewriter([SplitDims])
def local_split_dims_to_reshape(fgraph, node):
"""
Canonicalize SplitDims Ops to Reshape Ops for further graph reasoning (and dispatch to other backends).
"""
x, shape = node.inputs
axis = node.op.axis
output_shape = [
*x.shape[:axis],
*shape,
*x.shape[axis + 1 :],
]
new_x = x.reshape(output_shape)
copy_stack_trace(x, new_x)
return [new_x]
@register_canonicalize
@node_rewriter([JoinDims])
def local_join_dims_to_reshape(fgraph, node):
"""
Canonicalize JoinDims Ops to Reshape Ops for further graph reasoning (and dispatch to other backends).
"""
(x,) = node.inputs
start_axis = node.op.start_axis
n_axes = node.op.n_axes
output_shape = [
*x.shape[:start_axis],
-1,
*x.shape[start_axis + n_axes :],
]
new_x = x.reshape(output_shape)
copy_stack_trace(x, new_x)
return [new_x]
from pytensor.graph import FunctionGraph, rewrite_graph
from pytensor.tensor.reshape import JoinDims, SplitDims, join_dims, split_dims
from pytensor.tensor.shape import Reshape
from pytensor.tensor.type import tensor
def test_local_split_dims_to_reshape():
x = tensor("x", shape=(2, 10, 3))
x_split = split_dims(x, axis=1, shape=(2, 5, 1))
fg = FunctionGraph(inputs=[x], outputs=[x_split])
assert sum([1 for node in fg.toposort() if isinstance(node.op, SplitDims)]) == 1
rewrite_graph(fg, include=("canonicalize",))
assert sum([1 for node in fg.toposort() if isinstance(node.op, SplitDims)]) == 0
assert sum([1 for node in fg.toposort() if isinstance(node.op, Reshape)]) == 1
assert fg.outputs[0].type.shape == (2, 2, 5, 1, 3)
def test_local_join_dims_to_reshape():
x = tensor("x", shape=(2, 2, 5, 1, 3))
x_join = join_dims(x, axis=(1, 2, 3))
fg = FunctionGraph(inputs=[x], outputs=[x_join])
assert sum([1 for node in fg.toposort() if isinstance(node.op, JoinDims)]) == 1
rewrite_graph(fg, include=("canonicalize",))
assert sum([1 for node in fg.toposort() if isinstance(node.op, JoinDims)]) == 0
assert sum([1 for node in fg.toposort() if isinstance(node.op, Reshape)]) == 1
assert fg.outputs[0].type.shape == (2, 10, 3)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论