提交 af7ed248 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Faster infer_static_shape

上级 189ba03a
......@@ -22,10 +22,11 @@ import pytensor.scalar.sharedvar
from pytensor import compile, config, printing
from pytensor import scalar as aes
from pytensor.gradient import DisconnectedType, grad_undefined
from pytensor.graph import RewriteDatabaseQuery
from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op
from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.graph.rewriting.db import EquilibriumDB
from pytensor.graph.type import HasShape, Type
from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType
......@@ -1356,6 +1357,45 @@ def identity_like(x, dtype: Optional[Union[str, np.generic, np.dtype]] = None):
return eye(_x.shape[0], _x.shape[1], k=0, dtype=dtype)
class CachedEquilibrimDB(EquilibriumDB):
    """A subclass of EquilibriumDB that caches a default query for faster reuse.

    The result of querying the database with ``default_query`` is computed
    lazily on first access of the `default_query` property and then reused
    until a new rewrite is registered, at which point the cached result is
    discarded and recomputed on the next access.

    Parameters
    ----------
    default_query
        The query object passed to ``self.query`` to build the cached result.
    """

    def __init__(self, default_query):
        super().__init__()
        self._default_query = default_query
        # Result of ``self.query(self._default_query)``; ``None`` means "not computed yet".
        self._cached_default_query = None

    def register(self, *args, **kwargs):
        """Register a rewrite and invalidate the cached default query result."""
        # If new rewrites are registered, the default cached query is void.
        # This must reset the same private attribute that `default_query` reads,
        # otherwise rewrites registered after the first access are never picked up.
        self._cached_default_query = None
        super().register(*args, **kwargs)

    @property
    def default_query(self):
        """Return the (cached) result of querying the database with the default query."""
        if self._cached_default_query is None:
            self._cached_default_query = self.query(self._default_query)
        return self._cached_default_query
# Database of the rewrites registered via `register_infer_shape`; its default
# query includes everything tagged "infer_shape".
infer_shape_db = CachedEquilibrimDB(
    default_query=RewriteDatabaseQuery(include=("infer_shape",))
)
def register_infer_shape(rewrite, *tags, **kwargs):
    """Register a rewrite in `infer_shape_db` under the "infer_shape" tag.

    Can be used as a bare decorator (``@register_infer_shape``) or as a
    decorator factory whose arguments are extra tags
    (``@register_infer_shape("some_tag")``).

    Parameters
    ----------
    rewrite
        Either the rewrite to register, or a string. A string is treated as
        the first tag and a decorator is returned instead of registering.
    *tags
        Additional tags forwarded to ``infer_shape_db.register``.
    **kwargs
        Forwarded to ``infer_shape_db.register``. An optional ``name`` entry
        is used as the registration name; otherwise ``rewrite.__name__`` is.

    Returns
    -------
    The rewrite itself when one was given, otherwise a decorator that
    registers the rewrite it receives.
    """
    if not isinstance(rewrite, str):
        # Direct registration: an explicit (truthy) `name` wins over the rewrite's own name.
        registered_name = kwargs.pop("name", None) or rewrite.__name__
        infer_shape_db.register(
            registered_name, rewrite, *tags, "infer_shape", **kwargs
        )
        return rewrite

    # Factory form: `rewrite` is really the first tag, so pass it along as one.
    def register(inner_lopt):
        return register_infer_shape(inner_lopt, rewrite, *tags, **kwargs)

    return register
def infer_static_shape(
shape: Union[Variable, Sequence[Union[Variable, int]]]
) -> tuple[Sequence["TensorLike"], Sequence[Optional[int]]]:
......@@ -1390,14 +1430,16 @@ def infer_static_shape(
raise TypeError(f"Shapes must be scalar integers; got {s_as_str}")
sh = [check_type(as_tensor_variable(s, ndim=0)) for s in shape]
sh = folded_shape = [check_type(as_tensor_variable(s, ndim=0)) for s in shape]
if not all(isinstance(s, Constant) for s in folded_shape):
shape_fg = FunctionGraph(outputs=sh, features=[ShapeFeature()], clone=True)
with config.change_flags(optdb__max_use_ratio=10, cxx=""):
infer_shape_db.default_query.rewrite(shape_fg)
if not all(isinstance(s, Constant) for s in shape_fg.outputs):
topo_constant_folding.rewrite(shape_fg)
folded_shape = shape_fg.outputs
shape_fg = FunctionGraph(
outputs=sh,
features=[ShapeFeature()],
clone=True,
)
folded_shape = rewrite_graph(shape_fg, custom_rewrite=topo_constant_folding).outputs
static_shape = tuple(
s.data.item() if isinstance(s, Constant) else None for s in folded_shape
)
......
......@@ -58,6 +58,7 @@ from pytensor.tensor.basic import (
get_underlying_scalar_constant_value,
join,
ones_like,
register_infer_shape,
switch,
tensor_copy,
zeros,
......@@ -420,6 +421,7 @@ compile.optdb.register(
)
@register_infer_shape
@register_canonicalize("fast_compile", "shape_unsafe")
@register_useless("shape_unsafe")
@node_rewriter([fill])
......@@ -441,6 +443,7 @@ def local_useless_fill(fgraph, node):
return [v]
@register_infer_shape
@register_specialize("shape_unsafe")
@register_stabilize("shape_unsafe")
@register_canonicalize("shape_unsafe")
......@@ -530,6 +533,7 @@ compile.optdb.register(
)
@register_infer_shape
@register_useless
@register_canonicalize("fast_compile")
@register_specialize
......@@ -806,6 +810,7 @@ compile.optdb["useless"].register(
)
@register_infer_shape
@register_specialize
@register_canonicalize
@register_useless
......@@ -826,6 +831,7 @@ def local_join_1(fgraph, node):
# TODO: merge in local_useless_join
@register_infer_shape
@register_useless
@register_specialize
@register_canonicalize
......@@ -1066,6 +1072,7 @@ def local_merge_switch_same_cond(fgraph, node):
]
@register_infer_shape
@register_useless
@register_canonicalize
@register_specialize
......@@ -1149,6 +1156,7 @@ register_stabilize(topo_constant_folding, "fast_compile", final_rewriter=True)
register_specialize(topo_constant_folding, "fast_compile", final_rewriter=True)
@register_infer_shape
@register_canonicalize("fast_compile")
@register_useless("fast_compile")
@node_rewriter(None)
......@@ -1157,6 +1165,7 @@ def local_view_op(fgraph, node):
return node.inputs
@register_infer_shape
@register_useless
@register_canonicalize
@register_stabilize
......
......@@ -32,6 +32,7 @@ from pytensor.tensor.basic import (
extract_constant,
get_underlying_scalar_constant_value,
ones_like,
register_infer_shape,
switch,
zeros_like,
)
......@@ -1745,6 +1746,7 @@ def local_reduce_join(fgraph, node):
return [ret]
@register_infer_shape
@register_canonicalize("fast_compile", "local_cut_useless_reduce")
@register_useless("local_cut_useless_reduce")
@node_rewriter(ALL_REDUCE)
......
......@@ -25,6 +25,7 @@ from pytensor.tensor.basic import (
constant,
extract_constant,
get_underlying_scalar_constant_value,
register_infer_shape,
stack,
)
from pytensor.tensor.elemwise import DimShuffle, Elemwise
......@@ -964,6 +965,7 @@ def local_reshape_lift(fgraph, node):
return [e]
@register_infer_shape
@register_useless
@register_canonicalize
@node_rewriter([SpecifyShape])
......@@ -990,6 +992,17 @@ def local_merge_consecutive_specify_shape(fgraph, node):
return [specify_shape(inner_obj, shape)]
@register_infer_shape
@node_rewriter([Shape])
def local_shape_ground(fgraph, node):
    """Rewrite shape(x) -> make_vector(x.type.shape) when this is constant.

    The rewrite only applies when every entry of the input's static shape is
    known (not ``None``) and the input has at least one dimension.

    Returns
    -------
    A single-element list with the stacked int64 constants, or ``None`` when
    the rewrite does not apply.
    """
    [x] = node.inputs
    static_shape = x.type.shape
    # An input with no dimensions has an empty static shape; there is nothing
    # to stack in that case, so leave the node untouched instead of calling
    # `stack` with an empty list.
    if not static_shape or any(dim is None for dim in static_shape):
        return None
    return [stack([constant(dim, dtype="int64") for dim in static_shape])]
@register_infer_shape
@register_useless
@register_canonicalize
@node_rewriter([Shape])
......@@ -1014,6 +1027,7 @@ def local_Shape_of_SpecifyShape(fgraph, node):
return [stack(shape).astype(np.int64)]
@register_infer_shape
@register_canonicalize
@register_specialize
@node_rewriter([SpecifyShape])
......@@ -1060,6 +1074,7 @@ def local_specify_shape_lift(fgraph, node):
return new_out
@register_infer_shape
@register_useless
@register_canonicalize
@node_rewriter([Shape_i])
......@@ -1079,6 +1094,7 @@ def local_Shape_i_ground(fgraph, node):
return [as_tensor_variable(s_val, dtype=np.int64)]
@register_infer_shape
@register_specialize
@register_canonicalize
@node_rewriter([Shape])
......
......@@ -26,6 +26,7 @@ from pytensor.tensor.basic import (
concatenate,
extract_constant,
get_underlying_scalar_constant_value,
register_infer_shape,
switch,
)
from pytensor.tensor.elemwise import Elemwise
......@@ -328,6 +329,7 @@ def local_subtensor_of_dot(fgraph, node):
return [r]
@register_infer_shape
@register_useless
@register_canonicalize
@register_specialize
......@@ -599,6 +601,7 @@ def local_subtensor_remove_broadcastable_index(fgraph, node):
return [node.inputs[0].dimshuffle(tuple(remain_dim))]
@register_infer_shape
@register_useless
@register_canonicalize
@register_specialize
......@@ -707,6 +710,7 @@ def local_subtensor_inc_subtensor(fgraph, node):
return
@register_infer_shape
@register_specialize
@register_canonicalize("fast_compile")
@register_useless
......@@ -785,6 +789,7 @@ def local_subtensor_make_vector(fgraph, node):
pass
@register_infer_shape
@register_useless
@register_canonicalize
@register_specialize
......@@ -1461,6 +1466,7 @@ def local_adv_sub1_adv_inc_sub1(fgraph, node):
return [r2]
@register_infer_shape
@register_specialize
@register_stabilize
@register_canonicalize
......
import warnings
from numbers import Number
from textwrap import dedent
from typing import Union
from typing import Union, cast
import numpy as np
import pytensor
from pytensor.gradient import DisconnectedType
from pytensor.graph import Op
from pytensor.graph.basic import Apply, Variable
from pytensor.graph.replace import _vectorize_node
from pytensor.graph.type import HasShape
......@@ -145,14 +146,14 @@ _shape = Shape()
def shape(x: Union[np.ndarray, Number, Variable]) -> Variable:
"""Return the shape of `x`."""
if not isinstance(x, Variable):
x = at.as_tensor_variable(x)
x = at.as_tensor_variable(x) # type: ignore
return _shape(x)
return cast(Variable, _shape(x))
@_get_vector_length.register(Shape)
def _get_vector_length_Shape(op, var):
return var.owner.inputs[0].type.ndim
@_get_vector_length.register(Shape) # type: ignore
def _get_vector_length_Shape(op: Op, var: TensorVariable) -> int:
return cast(int, var.owner.inputs[0].type.ndim)
@_vectorize_node.register(Shape)
......@@ -181,7 +182,7 @@ def shape_tuple(x: TensorVariable) -> tuple[Variable, ...]:
# We assume/call it a scalar
return ()
res = ()
res: tuple[Variable, ...] = ()
symbolic_shape = shape(x)
static_shape = x.type.shape
for i in range(x.type.ndim):
......@@ -191,7 +192,7 @@ def shape_tuple(x: TensorVariable) -> tuple[Variable, ...]:
# TODO: Why not use uint64?
res += (pytensor.scalar.ScalarConstant(pytensor.scalar.int64, shape_val),)
else:
res += (symbolic_shape[i],)
res += (symbolic_shape[i],) # type: ignore
return res
......@@ -366,7 +367,7 @@ def shape_i_op(i):
return shape_i_op.cache[key]
shape_i_op.cache = {}
shape_i_op.cache = {} # type: ignore
def register_shape_i_c_code(typ, code, check_input, version=()):
......@@ -578,7 +579,7 @@ def specify_shape(
# If the specified shape is already encoded in the input static shape, do nothing
# This ignores PyTensor constants in shape
x = at.as_tensor_variable(x)
x = at.as_tensor_variable(x) # type: ignore
new_shape_info = any(
s != xts for (s, xts) in zip(shape, x.type.shape) if s is not None
)
......@@ -589,10 +590,10 @@ def specify_shape(
return _specify_shape(x, *shape)
@_get_vector_length.register(SpecifyShape)
def _get_vector_length_SpecifyShape(op, var):
@_get_vector_length.register(SpecifyShape) # type: ignore
def _get_vector_length_SpecifyShape(op: Op, var: TensorVariable) -> int:
try:
return at.get_underlying_scalar_constant_value(var.owner.inputs[1]).item()
return int(at.get_underlying_scalar_constant_value(var.owner.inputs[1]).item())
except NotScalarConstantError:
raise ValueError(f"Length of {var} cannot be determined")
......@@ -1104,4 +1105,4 @@ def _vectorize_unbroadcast(op: Unbroadcast, node: Apply, x: TensorVariable) -> A
batched_ndims = x.type.ndim - node.inputs[0].type.ndim
old_axes = op.axes
new_axes = (old_axis + batched_ndims for old_axis in old_axes)
return unbroadcast(x, *new_axes).owner
return cast(Apply, unbroadcast(x, *new_axes).owner)
......@@ -27,7 +27,6 @@ pytensor/tensor/random/basic.py
pytensor/tensor/random/op.py
pytensor/tensor/random/utils.py
pytensor/tensor/rewriting/basic.py
pytensor/tensor/shape.py
pytensor/tensor/slinalg.py
pytensor/tensor/subtensor.py
pytensor/tensor/type.py
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论