Commit af7ed248 authored by Ricardo Vieira, committed by Ricardo Vieira

Faster infer_static_shape

Parent 189ba03a
@@ -22,10 +22,11 @@ import pytensor.scalar.sharedvar
 from pytensor import compile, config, printing
 from pytensor import scalar as aes
 from pytensor.gradient import DisconnectedType, grad_undefined
+from pytensor.graph import RewriteDatabaseQuery
 from pytensor.graph.basic import Apply, Constant, Variable
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import Op
-from pytensor.graph.rewriting.utils import rewrite_graph
+from pytensor.graph.rewriting.db import EquilibriumDB
 from pytensor.graph.type import HasShape, Type
 from pytensor.link.c.op import COp
 from pytensor.link.c.params_type import ParamsType
@@ -1356,6 +1357,45 @@ def identity_like(x, dtype: Optional[Union[str, np.generic, np.dtype]] = None):
     return eye(_x.shape[0], _x.shape[1], k=0, dtype=dtype)


+class CachedEquilibrimDB(EquilibriumDB):
+    """A subclass of EquilibriumDB that allows caching of a default query for faster reuse."""
+
+    def __init__(self, default_query):
+        super().__init__()
+        self._default_query = default_query
+        self._cached_default_query = None
+
+    def register(self, *args, **kwargs):
+        # If new rewrites are registered, the default cached query is void
+        self._cached_default_query = None
+        super().register(*args, **kwargs)
+
+    @property
+    def default_query(self):
+        if self._cached_default_query is None:
+            self._cached_default_query = self.query(self._default_query)
+        return self._cached_default_query
+
+
+infer_shape_db = CachedEquilibrimDB(
+    default_query=RewriteDatabaseQuery(include=("infer_shape",))
+)
+
+
+def register_infer_shape(rewrite, *tags, **kwargs):
+    if isinstance(rewrite, str):
+
+        def register(inner_lopt):
+            return register_infer_shape(inner_lopt, rewrite, *tags, **kwargs)
+
+        return register
+    else:
+        name = kwargs.pop("name", None) or rewrite.__name__
+        infer_shape_db.register(name, rewrite, *tags, "infer_shape", **kwargs)
+        return rewrite
+
+
 def infer_static_shape(
     shape: Union[Variable, Sequence[Union[Variable, int]]]
 ) -> tuple[Sequence["TensorLike"], Sequence[Optional[int]]]:
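The cached database avoids rebuilding the RewriteDatabaseQuery on every infer_static_shape call: the query is compiled once on first access to default_query and is only invalidated when a new rewrite is registered. A minimal sketch of how register_infer_shape is meant to be used as a decorator; the toy local_shape_example rewrite below is hypothetical and simply mirrors the local_shape_ground rewrite added later in this commit:

from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.basic import constant, infer_shape_db, register_infer_shape, stack
from pytensor.tensor.shape import Shape


@register_infer_shape  # registers under the "infer_shape" tag in infer_shape_db
@node_rewriter([Shape])
def local_shape_example(fgraph, node):
    # Hypothetical toy rewrite: fold shape(x) into a constant vector
    # whenever every static dimension of x is known.
    [x] = node.inputs
    if not any(dim is None for dim in x.type.shape):
        return [stack([constant(dim, dtype="int64") for dim in x.type.shape])]


# The first access compiles the query; later accesses reuse the cached object.
q1 = infer_shape_db.default_query
assert q1 is infer_shape_db.default_query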
@@ -1390,14 +1430,16 @@ def infer_static_shape(
             raise TypeError(f"Shapes must be scalar integers; got {s_as_str}")

-    sh = [check_type(as_tensor_variable(s, ndim=0)) for s in shape]
+    sh = folded_shape = [check_type(as_tensor_variable(s, ndim=0)) for s in shape]

-    shape_fg = FunctionGraph(
-        outputs=sh,
-        features=[ShapeFeature()],
-        clone=True,
-    )
-    folded_shape = rewrite_graph(shape_fg, custom_rewrite=topo_constant_folding).outputs
+    if not all(isinstance(s, Constant) for s in folded_shape):
+        shape_fg = FunctionGraph(outputs=sh, features=[ShapeFeature()], clone=True)
+        with config.change_flags(optdb__max_use_ratio=10, cxx=""):
+            infer_shape_db.default_query.rewrite(shape_fg)
+            if not all(isinstance(s, Constant) for s in shape_fg.outputs):
+                topo_constant_folding.rewrite(shape_fg)
+        folded_shape = shape_fg.outputs
+
     static_shape = tuple(
         s.data.item() if isinstance(s, Constant) else None for s in folded_shape
     )
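With this change, infer_static_shape skips graph rewriting entirely when every shape entry is already a Constant, and otherwise runs only the small "infer_shape" query (with the C compiler disabled via cxx="") before falling back to constant folding. A usage sketch of the expected behaviour, assuming this commit:

import pytensor.tensor as pt
from pytensor.tensor.basic import infer_static_shape

x = pt.tensor("x", shape=(5, None))

# 2 is already constant; x.shape[0] folds to 5 through the "infer_shape"
# rewrites; x.shape[1] is genuinely unknown and stays symbolic.
folded, static = infer_static_shape((2, x.shape[0], x.shape[1]))
print(static)  # (2, 5, None)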
...
@@ -58,6 +58,7 @@ from pytensor.tensor.basic import (
     get_underlying_scalar_constant_value,
     join,
     ones_like,
+    register_infer_shape,
     switch,
     tensor_copy,
     zeros,
@@ -420,6 +421,7 @@ compile.optdb.register(
 )


+@register_infer_shape
 @register_canonicalize("fast_compile", "shape_unsafe")
 @register_useless("shape_unsafe")
 @node_rewriter([fill])
@@ -441,6 +443,7 @@ def local_useless_fill(fgraph, node):
         return [v]


+@register_infer_shape
 @register_specialize("shape_unsafe")
 @register_stabilize("shape_unsafe")
 @register_canonicalize("shape_unsafe")
@@ -530,6 +533,7 @@ compile.optdb.register(
 )


+@register_infer_shape
 @register_useless
 @register_canonicalize("fast_compile")
 @register_specialize
@@ -806,6 +810,7 @@ compile.optdb["useless"].register(
 )


+@register_infer_shape
 @register_specialize
 @register_canonicalize
 @register_useless
@@ -826,6 +831,7 @@ def local_join_1(fgraph, node):


 # TODO: merge in local_useless_join
+@register_infer_shape
 @register_useless
 @register_specialize
 @register_canonicalize
@@ -1066,6 +1072,7 @@ def local_merge_switch_same_cond(fgraph, node):
     ]


+@register_infer_shape
 @register_useless
 @register_canonicalize
 @register_specialize
@@ -1149,6 +1156,7 @@ register_stabilize(topo_constant_folding, "fast_compile", final_rewriter=True)
 register_specialize(topo_constant_folding, "fast_compile", final_rewriter=True)


+@register_infer_shape
 @register_canonicalize("fast_compile")
 @register_useless("fast_compile")
 @node_rewriter(None)
@@ -1157,6 +1165,7 @@ def local_view_op(fgraph, node):
         return node.inputs


+@register_infer_shape
 @register_useless
 @register_canonicalize
 @register_stabilize
...
@@ -32,6 +32,7 @@ from pytensor.tensor.basic import (
     extract_constant,
     get_underlying_scalar_constant_value,
     ones_like,
+    register_infer_shape,
     switch,
     zeros_like,
 )
@@ -1745,6 +1746,7 @@ def local_reduce_join(fgraph, node):
         return [ret]


+@register_infer_shape
 @register_canonicalize("fast_compile", "local_cut_useless_reduce")
 @register_useless("local_cut_useless_reduce")
 @node_rewriter(ALL_REDUCE)
...
@@ -25,6 +25,7 @@ from pytensor.tensor.basic import (
     constant,
     extract_constant,
     get_underlying_scalar_constant_value,
+    register_infer_shape,
     stack,
 )
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
@@ -964,6 +965,7 @@ def local_reshape_lift(fgraph, node):
             return [e]


+@register_infer_shape
 @register_useless
 @register_canonicalize
 @node_rewriter([SpecifyShape])
@@ -990,6 +992,17 @@ def local_merge_consecutive_specify_shape(fgraph, node):
     return [specify_shape(inner_obj, shape)]


+@register_infer_shape
+@node_rewriter([Shape])
+def local_shape_ground(fgraph, node):
+    """Rewrite shape(x) -> make_vector(x.type.shape) when this is constant."""
+    [x] = node.inputs
+    static_shape = x.type.shape
+    if not any(dim is None for dim in static_shape):
+        return [stack([constant(dim, dtype="int64") for dim in static_shape])]
+
+
+@register_infer_shape
 @register_useless
 @register_canonicalize
 @node_rewriter([Shape])
@@ -1014,6 +1027,7 @@ def local_Shape_of_SpecifyShape(fgraph, node):
     return [stack(shape).astype(np.int64)]


+@register_infer_shape
 @register_canonicalize
 @register_specialize
 @node_rewriter([SpecifyShape])
@@ -1060,6 +1074,7 @@ def local_specify_shape_lift(fgraph, node):
         return new_out


+@register_infer_shape
 @register_useless
 @register_canonicalize
 @node_rewriter([Shape_i])
@@ -1079,6 +1094,7 @@ def local_Shape_i_ground(fgraph, node):
         return [as_tensor_variable(s_val, dtype=np.int64)]


+@register_infer_shape
 @register_specialize
 @register_canonicalize
 @node_rewriter([Shape])
...
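local_shape_ground lets the shape graph short-circuit as soon as the static type information is complete. The same folding is also visible under normal compilation (through the Shape_i rewrites registered above): once shape(x) is reduced to constants, nothing reads the shape off the runtime value. A small sketch, assuming this commit:

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.tensor("x", shape=(2, 3))
fn = pytensor.function([x], x.shape)
# After rewriting, shape(x) has been folded into the constant vector [2, 3]
# instead of being computed from the input at runtime.
print(fn(np.zeros((2, 3), dtype=x.dtype)))  # [2 3]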
@@ -26,6 +26,7 @@ from pytensor.tensor.basic import (
     concatenate,
     extract_constant,
     get_underlying_scalar_constant_value,
+    register_infer_shape,
     switch,
 )
 from pytensor.tensor.elemwise import Elemwise
@@ -328,6 +329,7 @@ def local_subtensor_of_dot(fgraph, node):
     return [r]


+@register_infer_shape
 @register_useless
 @register_canonicalize
 @register_specialize
@@ -599,6 +601,7 @@ def local_subtensor_remove_broadcastable_index(fgraph, node):
         return [node.inputs[0].dimshuffle(tuple(remain_dim))]


+@register_infer_shape
 @register_useless
 @register_canonicalize
 @register_specialize
@@ -707,6 +710,7 @@ def local_subtensor_inc_subtensor(fgraph, node):
             return


+@register_infer_shape
 @register_specialize
 @register_canonicalize("fast_compile")
 @register_useless
@@ -785,6 +789,7 @@ def local_subtensor_make_vector(fgraph, node):
             pass


+@register_infer_shape
 @register_useless
 @register_canonicalize
 @register_specialize
@@ -1461,6 +1466,7 @@ def local_adv_sub1_adv_inc_sub1(fgraph, node):
     return [r2]


+@register_infer_shape
 @register_specialize
 @register_stabilize
 @register_canonicalize
...
 import warnings
 from numbers import Number
 from textwrap import dedent
-from typing import Union
+from typing import Union, cast

 import numpy as np

 import pytensor
 from pytensor.gradient import DisconnectedType
+from pytensor.graph import Op
 from pytensor.graph.basic import Apply, Variable
 from pytensor.graph.replace import _vectorize_node
 from pytensor.graph.type import HasShape
@@ -145,14 +146,14 @@ _shape = Shape()

 def shape(x: Union[np.ndarray, Number, Variable]) -> Variable:
     """Return the shape of `x`."""
     if not isinstance(x, Variable):
-        x = at.as_tensor_variable(x)
+        x = at.as_tensor_variable(x)  # type: ignore

-    return _shape(x)
+    return cast(Variable, _shape(x))


-@_get_vector_length.register(Shape)
-def _get_vector_length_Shape(op, var):
-    return var.owner.inputs[0].type.ndim
+@_get_vector_length.register(Shape)  # type: ignore
+def _get_vector_length_Shape(op: Op, var: TensorVariable) -> int:
+    return cast(int, var.owner.inputs[0].type.ndim)


 @_vectorize_node.register(Shape)
@@ -181,7 +182,7 @@ def shape_tuple(x: TensorVariable) -> tuple[Variable, ...]:
         # We assume/call it a scalar
         return ()

-    res = ()
+    res: tuple[Variable, ...] = ()
     symbolic_shape = shape(x)
     static_shape = x.type.shape
     for i in range(x.type.ndim):
@@ -191,7 +192,7 @@ def shape_tuple(x: TensorVariable) -> tuple[Variable, ...]:
             # TODO: Why not use uint64?
             res += (pytensor.scalar.ScalarConstant(pytensor.scalar.int64, shape_val),)
         else:
-            res += (symbolic_shape[i],)
+            res += (symbolic_shape[i],)  # type: ignore

     return res
@@ -366,7 +367,7 @@ def shape_i_op(i):
     return shape_i_op.cache[key]


-shape_i_op.cache = {}
+shape_i_op.cache = {}  # type: ignore


 def register_shape_i_c_code(typ, code, check_input, version=()):
@@ -578,7 +579,7 @@ def specify_shape(
     # If the specified shape is already encoded in the input static shape, do nothing
     # This ignores PyTensor constants in shape
-    x = at.as_tensor_variable(x)
+    x = at.as_tensor_variable(x)  # type: ignore
     new_shape_info = any(
         s != xts for (s, xts) in zip(shape, x.type.shape) if s is not None
     )
@@ -589,10 +590,10 @@ def specify_shape(
     return _specify_shape(x, *shape)


-@_get_vector_length.register(SpecifyShape)
-def _get_vector_length_SpecifyShape(op, var):
+@_get_vector_length.register(SpecifyShape)  # type: ignore
+def _get_vector_length_SpecifyShape(op: Op, var: TensorVariable) -> int:
     try:
-        return at.get_underlying_scalar_constant_value(var.owner.inputs[1]).item()
+        return int(at.get_underlying_scalar_constant_value(var.owner.inputs[1]).item())
     except NotScalarConstantError:
         raise ValueError(f"Length of {var} cannot be determined")
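The handler now coerces its result to a plain int, matching the annotated return type. How the two _get_vector_length registrations behave through the public get_vector_length helper (a sketch, assuming this commit):

import pytensor.tensor as pt
from pytensor.tensor import get_vector_length, specify_shape

v = specify_shape(pt.vector("v"), (3,))
assert get_vector_length(v) == 3        # static length from SpecifyShape

x = pt.matrix("x")
assert get_vector_length(x.shape) == 2  # from Shape: the ndim of its input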
@@ -1104,4 +1105,4 @@ def _vectorize_unbroadcast(op: Unbroadcast, node: Apply, x: TensorVariable) -> Apply:
     batched_ndims = x.type.ndim - node.inputs[0].type.ndim
     old_axes = op.axes
     new_axes = (old_axis + batched_ndims for old_axis in old_axes)
-    return unbroadcast(x, *new_axes).owner
+    return cast(Apply, unbroadcast(x, *new_axes).owner)
...
@@ -27,7 +27,6 @@ pytensor/tensor/random/basic.py
 pytensor/tensor/random/op.py
 pytensor/tensor/random/utils.py
 pytensor/tensor/rewriting/basic.py
-pytensor/tensor/shape.py
 pytensor/tensor/slinalg.py
 pytensor/tensor/subtensor.py
 pytensor/tensor/type.py
...