Commit a0be97e8 authored by jessegrabowski, committed by Ricardo Vieira

Canonicalize `SplitDims` and `JoinDims` to `Reshape`

Parent: c3347b98
@@ -10,6 +10,7 @@ import pytensor.tensor.rewriting.linalg
import pytensor.tensor.rewriting.math
import pytensor.tensor.rewriting.numba
import pytensor.tensor.rewriting.ofg
import pytensor.tensor.rewriting.reshape
import pytensor.tensor.rewriting.shape
import pytensor.tensor.rewriting.special
import pytensor.tensor.rewriting.subtensor
...
from pytensor.graph import node_rewriter
from pytensor.graph.rewriting.basic import copy_stack_trace
from pytensor.tensor.reshape import JoinDims, SplitDims
from pytensor.tensor.rewriting.basic import register_canonicalize
@register_canonicalize
@node_rewriter([SplitDims])
def local_split_dims_to_reshape(fgraph, node):
    """Canonicalize a ``SplitDims`` Op into an equivalent ``Reshape``.

    Rewriting to ``Reshape`` enables further graph reasoning and dispatch to
    other backends.
    """
    tensor_in, split_shape = node.inputs
    split_axis = node.op.axis

    # Target shape: dims before the split axis, then the new sub-dims that
    # replace it, then the dims after it.
    target_shape = (
        list(tensor_in.shape[:split_axis])
        + list(split_shape)
        + list(tensor_in.shape[split_axis + 1 :])
    )

    reshaped = tensor_in.reshape(target_shape)
    copy_stack_trace(tensor_in, reshaped)
    return [reshaped]
@register_canonicalize
@node_rewriter([JoinDims])
def local_join_dims_to_reshape(fgraph, node):
    """Canonicalize a ``JoinDims`` Op into an equivalent ``Reshape``.

    Rewriting to ``Reshape`` enables further graph reasoning and dispatch to
    other backends.
    """
    [tensor_in] = node.inputs
    first_axis = node.op.start_axis
    joined_count = node.op.n_axes

    # -1 lets Reshape infer the extent of the single merged dimension that
    # replaces the `joined_count` axes starting at `first_axis`.
    collapsed_shape = [
        *tensor_in.shape[:first_axis],
        -1,
        *tensor_in.shape[first_axis + joined_count :],
    ]

    reshaped = tensor_in.reshape(collapsed_shape)
    copy_stack_trace(tensor_in, reshaped)
    return [reshaped]
from pytensor.graph import FunctionGraph, rewrite_graph
from pytensor.tensor.reshape import JoinDims, SplitDims, join_dims, split_dims
from pytensor.tensor.shape import Reshape
from pytensor.tensor.type import tensor
def test_local_split_dims_to_reshape():
    """A single SplitDims node should canonicalize into a single Reshape node."""
    x = tensor("x", shape=(2, 10, 3))
    x_split = split_dims(x, axis=1, shape=(2, 5, 1))
    fg = FunctionGraph(inputs=[x], outputs=[x_split])

    # Count via a generator instead of materializing a throwaway list
    # (sum([...]) idiom flagged by ruff C419); True counts as 1.
    assert sum(isinstance(node.op, SplitDims) for node in fg.toposort()) == 1

    rewrite_graph(fg, include=("canonicalize",))

    # The SplitDims node is gone, replaced by exactly one Reshape.
    assert not any(isinstance(node.op, SplitDims) for node in fg.toposort())
    assert sum(isinstance(node.op, Reshape) for node in fg.toposort()) == 1
    # Axis 1 (extent 10) is replaced by the (2, 5, 1) sub-dims.
    assert fg.outputs[0].type.shape == (2, 2, 5, 1, 3)
def test_local_join_dims_to_reshape():
    """A single JoinDims node should canonicalize into a single Reshape node."""
    x = tensor("x", shape=(2, 2, 5, 1, 3))
    x_join = join_dims(x, axis=(1, 2, 3))
    fg = FunctionGraph(inputs=[x], outputs=[x_join])

    # Count via a generator instead of materializing a throwaway list
    # (sum([...]) idiom flagged by ruff C419); True counts as 1.
    assert sum(isinstance(node.op, JoinDims) for node in fg.toposort()) == 1

    rewrite_graph(fg, include=("canonicalize",))

    # The JoinDims node is gone, replaced by exactly one Reshape.
    assert not any(isinstance(node.op, JoinDims) for node in fg.toposort())
    assert sum(isinstance(node.op, Reshape) for node in fg.toposort()) == 1
    # Axes 1-3 (extents 2, 5, 1) merge into a single axis of extent 10.
    assert fg.outputs[0].type.shape == (2, 10, 3)
Markdown is supported
0%
You are about to add 0 people to this discussion. Proceed with caution.
Finish editing this message first!
Register or sign in to post a comment