Unverified 提交 cac9febf authored 作者: Jesse Grabowski's avatar Jesse Grabowski 提交者: GitHub

Add `OpFromGraph` wrapper around `alloc_diag` (#915)

* Add `OpFromGraph` wrapper around `alloc_diag` * Remove depreciated `AllocDiag` `Op`, rename `AllocDiag2 -> AllocDiag` * Set `inline = False` * Add rewrite to inline all `OpFromGraph` `Op`s * Add `is_zero_offset` helper to `Eye` * Add `is_left_expand_dims` and `is_right_expand_dims` attributes to `DimShuffle` * Seed `test_local_lift_through_linalg` test
上级 6ad1c5cf
...@@ -8,7 +8,6 @@ from typing import Union, cast ...@@ -8,7 +8,6 @@ from typing import Union, cast
from pytensor.compile.function import function from pytensor.compile.function import function
from pytensor.compile.function.pfunc import rebuild_collect_shared from pytensor.compile.function.pfunc import rebuild_collect_shared
from pytensor.compile.mode import optdb
from pytensor.compile.sharedvalue import SharedVariable from pytensor.compile.sharedvalue import SharedVariable
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.gradient import DisconnectedType, Rop, grad from pytensor.gradient import DisconnectedType, Rop, grad
...@@ -24,7 +23,6 @@ from pytensor.graph.fg import FunctionGraph ...@@ -24,7 +23,6 @@ from pytensor.graph.fg import FunctionGraph
from pytensor.graph.null_type import NullType from pytensor.graph.null_type import NullType
from pytensor.graph.op import HasInnerGraph, Op from pytensor.graph.op import HasInnerGraph, Op
from pytensor.graph.replace import clone_replace from pytensor.graph.replace import clone_replace
from pytensor.graph.rewriting.basic import in2out, node_rewriter
from pytensor.graph.utils import MissingInputError from pytensor.graph.utils import MissingInputError
...@@ -575,7 +573,7 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -575,7 +573,7 @@ class OpFromGraph(Op, HasInnerGraph):
for inp_grad in input_grads for inp_grad in input_grads
if not isinstance(inp_grad.type, DisconnectedType | NullType) if not isinstance(inp_grad.type, DisconnectedType | NullType)
] ]
lop_op = type(self)( lop_op = OpFromGraph(
inputs=inner_inputs + connected_inner_outputs + connected_output_grads, inputs=inner_inputs + connected_inner_outputs + connected_output_grads,
outputs=connected_input_grads, outputs=connected_input_grads,
inline=self.is_inline, inline=self.is_inline,
...@@ -669,7 +667,7 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -669,7 +667,7 @@ class OpFromGraph(Op, HasInnerGraph):
for out_grad in output_grads for out_grad in output_grads
if not isinstance(out_grad.type, DisconnectedType | NullType) if not isinstance(out_grad.type, DisconnectedType | NullType)
] ]
rop_op = type(self)( rop_op = OpFromGraph(
inputs=inner_inputs + eval_points, inputs=inner_inputs + eval_points,
outputs=filtered_output_grads, outputs=filtered_output_grads,
inline=self.is_inline, inline=self.is_inline,
...@@ -852,29 +850,3 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -852,29 +850,3 @@ class OpFromGraph(Op, HasInnerGraph):
assert len(variables) == len(outputs) assert len(variables) == len(outputs)
for output, variable in zip(outputs, variables): for output, variable in zip(outputs, variables):
output[0] = variable output[0] = variable
@node_rewriter([OpFromGraph])
def inline_ofg_expansion(fgraph, node):
"""
This optimization expands internal graph of OpFromGraph.
Only performed if node.op.is_inline == True
Doing so can improve optimization at the cost of compilation speed.
"""
op = node.op
if not isinstance(op, OpFromGraph):
return False
if not op.is_inline:
return False
return clone_replace(op.inner_outputs, dict(zip(op.inner_inputs, node.inputs)))
# We want to run this before the first merge optimizer
# and before the first scan optimizer.
optdb.register(
"inline_ofg_expansion",
in2out(inline_ofg_expansion),
"fast_compile",
"fast_run",
position=-0.01,
)
import warnings import warnings
from collections.abc import Callable
from functools import singledispatch from functools import singledispatch
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import numpy as np import numpy as np
from pytensor.compile import JAX
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.ops import DeepCopyOp, ViewOp from pytensor.compile.ops import DeepCopyOp, ViewOp
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
...@@ -114,3 +117,24 @@ def jax_funcify_ViewOp(op, **kwargs): ...@@ -114,3 +117,24 @@ def jax_funcify_ViewOp(op, **kwargs):
return x return x
return viewop return viewop
@jax_funcify.register(OpFromGraph)
def jax_funcify_OpFromGraph(ofg: OpFromGraph, node=None, **kwargs) -> Callable:
_ = kwargs.pop("storage_map", None)
# Apply inner rewrites
JAX.optimizer(ofg.fgraph)
fgraph_fn = jax_funcify(ofg.fgraph, **kwargs)
if len(ofg.fgraph.outputs) == 1:
def opfromgraph(*inputs):
return fgraph_fn(*inputs)[0]
else:
def opfromgraph(*inputs):
return fgraph_fn(*inputs)
return opfromgraph
...@@ -21,6 +21,7 @@ import pytensor ...@@ -21,6 +21,7 @@ import pytensor
import pytensor.scalar.sharedvar import pytensor.scalar.sharedvar
from pytensor import compile, config, printing from pytensor import compile, config, printing
from pytensor import scalar as ps from pytensor import scalar as ps
from pytensor.compile.builders import OpFromGraph
from pytensor.gradient import DisconnectedType, grad_undefined from pytensor.gradient import DisconnectedType, grad_undefined
from pytensor.graph import RewriteDatabaseQuery from pytensor.graph import RewriteDatabaseQuery
from pytensor.graph.basic import Apply, Constant, Variable, equal_computations from pytensor.graph.basic import Apply, Constant, Variable, equal_computations
...@@ -1334,6 +1335,25 @@ class Eye(Op): ...@@ -1334,6 +1335,25 @@ class Eye(Op):
def grad(self, inp, grads): def grad(self, inp, grads):
return [grad_undefined(self, i, inp[i]) for i in range(3)] return [grad_undefined(self, i, inp[i]) for i in range(3)]
@staticmethod
def is_offset_zero(node) -> bool:
"""
Test if an Eye Op has a diagonal offset of zero
Parameters
----------
node
Eye node to test
Returns
-------
is_offset_zero: bool
True if the offset is zero (``k = 0``).
"""
offset = node.inputs[-1]
return isinstance(offset, Constant) and offset.data.item() == 0
def eye(n, m=None, k=0, dtype=None): def eye(n, m=None, k=0, dtype=None):
"""Return a 2-D array with ones on the diagonal and zeros elsewhere. """Return a 2-D array with ones on the diagonal and zeros elsewhere.
...@@ -3749,109 +3769,37 @@ def trace(a, offset=0, axis1=0, axis2=1): ...@@ -3749,109 +3769,37 @@ def trace(a, offset=0, axis1=0, axis2=1):
return diagonal(a, offset=offset, axis1=axis1, axis2=axis2).sum(-1) return diagonal(a, offset=offset, axis1=axis1, axis2=axis2).sum(-1)
class AllocDiag(Op): class AllocDiag(OpFromGraph):
"""An `Op` that copies a vector to the diagonal of a zero-ed matrix.""" """
Wrapper Op for alloc_diag graphs
"""
__props__ = ("offset", "axis1", "axis2") __props__ = ("axis1", "axis2")
def __init__(self, offset=0, axis1=0, axis2=1): def __init__(self, *args, axis1, axis2, offset, **kwargs):
"""
Parameters
----------
offset: int
Offset of the diagonal from the main diagonal defined by `axis1`
and `axis2`. Can be positive or negative. Defaults to main
diagonal (i.e. 0).
axis1: int
Axis to be used as the first axis of the 2-D sub-arrays to which
the diagonals will be allocated. Defaults to first axis (i.e. 0).
axis2: int
Axis to be used as the second axis of the 2-D sub-arrays to which
the diagonals will be allocated. Defaults to second axis (i.e. 1).
"""
warnings.warn(
"AllocDiag is deprecated. Use `alloc_diag` instead",
FutureWarning,
)
self.offset = offset
if axis1 < 0 or axis2 < 0:
raise NotImplementedError("AllocDiag does not support negative axis")
if axis1 == axis2:
raise ValueError("axis1 and axis2 cannot be the same")
self.axis1 = axis1 self.axis1 = axis1
self.axis2 = axis2 self.axis2 = axis2
self.offset = offset
def make_node(self, diag): super().__init__(*args, **kwargs, strict=True)
diag = as_tensor_variable(diag)
if diag.type.ndim < 1:
raise ValueError(
"AllocDiag needs an input with 1 or more dimensions", diag.type
)
return Apply(
self,
[diag],
[diag.type.clone(shape=(None,) * (diag.ndim + 1))()],
)
def perform(self, node, inputs, outputs):
(x,) = inputs
(z,) = outputs
axis1 = np.minimum(self.axis1, self.axis2)
axis2 = np.maximum(self.axis1, self.axis2)
offset = self.offset
# Create array with one extra dimension for resulting matrix
result_shape = x.shape[:-1] + (x.shape[-1] + abs(offset),) * 2
result = np.zeros(result_shape, dtype=x.dtype)
# Create slice for diagonal in final 2 axes
idxs = np.arange(x.shape[-1])
diagonal_slice = (len(result_shape) - 2) * [slice(None)] + [
idxs + np.maximum(0, -offset),
idxs + np.maximum(0, offset),
]
# Fill in final 2 axes with x
result[tuple(diagonal_slice)] = x
if len(x.shape) > 1:
# Re-order axes so they correspond to diagonals at axis1, axis2
axes = list(range(len(x.shape[:-1])))
last_idx = axes[-1]
axes = axes[:axis1] + [last_idx + 1] + axes[axis1:]
axes = axes[:axis2] + [last_idx + 2] + axes[axis2:]
result = result.transpose(axes)
z[0] = result
def grad(self, inputs, gout):
(gz,) = gout
return [diagonal(gz, offset=self.offset, axis1=self.axis1, axis2=self.axis2)]
def infer_shape(self, fgraph, nodes, shapes):
(x_shape,) = shapes
axis1 = np.minimum(self.axis1, self.axis2)
axis2 = np.maximum(self.axis1, self.axis2)
result_shape = list(x_shape[:-1]) @staticmethod
diag_shape = x_shape[-1] + abs(self.offset) def is_offset_zero(node) -> bool:
result_shape = result_shape[:axis1] + [diag_shape] + result_shape[axis1:] """
result_shape = result_shape[:axis2] + [diag_shape] + result_shape[axis2:] Test if an AllocDiag Op has a diagonal offset of zero
return [tuple(result_shape)]
def __setstate__(self, state): Parameters
if "view_map" in state: ----------
del state["view_map"] node
AllocDiag node to test
self.__dict__.update(state) Returns
-------
is_offset_zero: bool
True if the offset is zero (``k = 0``).
"""
if "offset" not in state: return node.op.offset == 0
self.offset = 0
if "axis1" not in state:
self.axis1 = 0
if "axis2" not in state:
self.axis2 = 1
def alloc_diag(diag, offset=0, axis1=0, axis2=1): def alloc_diag(diag, offset=0, axis1=0, axis2=1):
...@@ -3862,6 +3810,7 @@ def alloc_diag(diag, offset=0, axis1=0, axis2=1): ...@@ -3862,6 +3810,7 @@ def alloc_diag(diag, offset=0, axis1=0, axis2=1):
from pytensor.tensor import set_subtensor from pytensor.tensor import set_subtensor
diag = as_tensor_variable(diag) diag = as_tensor_variable(diag)
axis1, axis2 = normalize_axis_tuple((axis1, axis2), ndim=diag.type.ndim + 1) axis1, axis2 = normalize_axis_tuple((axis1, axis2), ndim=diag.type.ndim + 1)
if axis1 > axis2: if axis1 > axis2:
axis1, axis2 = axis2, axis1 axis1, axis2 = axis2, axis1
...@@ -3888,7 +3837,9 @@ def alloc_diag(diag, offset=0, axis1=0, axis2=1): ...@@ -3888,7 +3837,9 @@ def alloc_diag(diag, offset=0, axis1=0, axis2=1):
axes = axes[:axis2] + [last_idx + 2] + axes[axis2:] axes = axes[:axis2] + [last_idx + 2] + axes[axis2:]
result = result.transpose(axes) result = result.transpose(axes)
return result return AllocDiag(
inputs=[diag], outputs=[result], axis1=axis1, axis2=axis2, offset=offset
)(diag)
def diag(v, k=0): def diag(v, k=0):
......
...@@ -185,6 +185,14 @@ class DimShuffle(ExternalCOp): ...@@ -185,6 +185,14 @@ class DimShuffle(ExternalCOp):
self.augment = sorted(i for i, x in enumerate(new_order) if x == "x") self.augment = sorted(i for i, x in enumerate(new_order) if x == "x")
self.drop = drop self.drop = drop
input_ndim = len(input_broadcastable)
self.is_left_expand_dims = self.augment and (
input_ndim == 0 or new_order[-input_ndim:] == list(range(input_ndim))
)
self.is_right_expand_dims = self.augment and new_order[:input_ndim] == list(
range(input_ndim)
)
if self.inplace: if self.inplace:
self.view_map = {0: [0]} self.view_map = {0: [0]}
......
...@@ -10,6 +10,7 @@ import pytensor.tensor.rewriting.extra_ops ...@@ -10,6 +10,7 @@ import pytensor.tensor.rewriting.extra_ops
import pytensor.tensor.rewriting.jax import pytensor.tensor.rewriting.jax
import pytensor.tensor.rewriting.linalg import pytensor.tensor.rewriting.linalg
import pytensor.tensor.rewriting.math import pytensor.tensor.rewriting.math
import pytensor.tensor.rewriting.ofg
import pytensor.tensor.rewriting.shape import pytensor.tensor.rewriting.shape
import pytensor.tensor.rewriting.special import pytensor.tensor.rewriting.special
import pytensor.tensor.rewriting.subtensor import pytensor.tensor.rewriting.subtensor
......
...@@ -5,12 +5,16 @@ from typing import cast ...@@ -5,12 +5,16 @@ from typing import cast
from pytensor import Variable from pytensor import Variable
from pytensor.graph import Apply, FunctionGraph from pytensor.graph import Apply, FunctionGraph
from pytensor.graph.rewriting.basic import ( from pytensor.graph.rewriting.basic import (
PatternNodeRewriter,
copy_stack_trace, copy_stack_trace,
node_rewriter, node_rewriter,
) )
from pytensor.scalar.basic import Mul from pytensor.scalar.basic import Mul
from pytensor.tensor.basic import ARange, Eye, TensorVariable, alloc, diagonal from pytensor.tensor.basic import (
AllocDiag,
Eye,
TensorVariable,
diagonal,
)
from pytensor.tensor.blas import Dot22 from pytensor.tensor.blas import Dot22
from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.elemwise import DimShuffle, Elemwise
...@@ -41,7 +45,6 @@ from pytensor.tensor.slinalg import ( ...@@ -41,7 +45,6 @@ from pytensor.tensor.slinalg import (
solve, solve,
solve_triangular, solve_triangular,
) )
from pytensor.tensor.subtensor import advanced_set_subtensor
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -402,30 +405,68 @@ def _find_diag_from_eye_mul(potential_mul_input): ...@@ -402,30 +405,68 @@ def _find_diag_from_eye_mul(potential_mul_input):
eye_input = [ eye_input = [
mul_input mul_input
for mul_input in inputs_to_mul for mul_input in inputs_to_mul
if mul_input.owner and isinstance(mul_input.owner.op, Eye) if mul_input.owner
and (
isinstance(mul_input.owner.op, Eye)
or
# This whole condition checks if there is an Eye hiding inside a DimShuffle.
# This arises from batched elementwise multiplication between a tensor and an eye, e.g.:
# tensor(shape=(None, 3, 3) * eye(3). This is still potentially valid for diag rewrites.
(
isinstance(mul_input.owner.op, DimShuffle)
and (
mul_input.owner.op.is_left_expand_dims
or mul_input.owner.op.is_right_expand_dims
)
and mul_input.owner.inputs[0].owner is not None
and isinstance(mul_input.owner.inputs[0].owner.op, Eye)
)
)
] ]
# Check if 1's are being put on the main diagonal only (k = 0) if not eye_input:
if eye_input and getattr(eye_input[0].owner.inputs[-1], "data", -1).item() != 0:
return None return None
# If the broadcast pattern of eye_input is not (False, False), we do not get a diagonal matrix and thus, dont need to apply the rewrite eye_input = eye_input[0]
if eye_input and eye_input[0].broadcastable[-2:] != (False, False): # If eye_input is an Eye Op (it's not wrapped in a DimShuffle), check it doesn't have an offset
if isinstance(eye_input.owner.op, Eye) and (
not Eye.is_offset_zero(eye_input.owner)
or eye_input.broadcastable[-2:] != (False, False)
):
return None return None
# Otherwise, an Eye was found but it is wrapped in a DimShuffle (i.e. there was some broadcasting going on).
# We have to look inside DimShuffle to decide if the rewrite can be applied
if isinstance(eye_input.owner.op, DimShuffle) and (
eye_input.owner.op.is_left_expand_dims
or eye_input.owner.op.is_right_expand_dims
):
inner_eye = eye_input.owner.inputs[0]
# We can only rewrite when the Eye is on the main diagonal (the offset is zero) and the identity isn't
# degenerate
if not Eye.is_offset_zero(inner_eye.owner) or inner_eye.broadcastable[-2:] != (
False,
False,
):
return None
# Get all non Eye inputs (scalars/matrices/vectors) # Get all non Eye inputs (scalars/matrices/vectors)
non_eye_inputs = list(set(inputs_to_mul) - set(eye_input)) non_eye_inputs = list(set(inputs_to_mul) - {eye_input})
return eye_input, non_eye_inputs return eye_input, non_eye_inputs
@register_canonicalize("shape_unsafe") @register_canonicalize("shape_unsafe")
@register_stabilize("shape_unsafe") @register_stabilize("shape_unsafe")
@node_rewriter([det]) @node_rewriter([det])
def rewrite_det_diag_from_eye_mul(fgraph, node): def rewrite_det_diag_to_prod_diag(fgraph, node):
""" """
This rewrite takes advantage of the fact that for a diagonal matrix, the determinant value is the product of its diagonal elements. This rewrite takes advantage of the fact that for a diagonal matrix, the determinant value is the product of its
diagonal elements.
The presence of a diagonal matrix is detected by inspecting the graph. This rewrite can identify diagonal matrices that arise as the result of elementwise multiplication with an identity matrix. Specialized computation is used to make this rewrite as efficient as possible, depending on whether the multiplication was with a scalar, vector or a matrix. The presence of a diagonal matrix is detected by inspecting the graph. This rewrite can identify diagonal matrices
that arise as the result of elementwise multiplication with an identity matrix. Specialized computation is used to
make this rewrite as efficient as possible, depending on whether the multiplication was with a scalar,
vector or a matrix.
Parameters Parameters
---------- ----------
...@@ -439,53 +480,45 @@ def rewrite_det_diag_from_eye_mul(fgraph, node): ...@@ -439,53 +480,45 @@ def rewrite_det_diag_from_eye_mul(fgraph, node):
list of Variable, optional list of Variable, optional
List of optimized variables, or None if no optimization was performed List of optimized variables, or None if no optimization was performed
""" """
potential_mul_input = node.inputs[0] inputs = node.inputs[0]
eye_non_eye_inputs = _find_diag_from_eye_mul(potential_mul_input)
if eye_non_eye_inputs is None: # Check for use of pt.diag first
if (
inputs.owner
and isinstance(inputs.owner.op, AllocDiag)
and AllocDiag.is_offset_zero(inputs.owner)
):
diag_input = inputs.owner.inputs[0]
det_val = diag_input.prod(axis=-1)
return [det_val]
# Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix
inputs_or_none = _find_diag_from_eye_mul(inputs)
if inputs_or_none is None:
return None return None
eye_input, non_eye_inputs = eye_non_eye_inputs
eye_input, non_eye_inputs = inputs_or_none
# Dealing with only one other input # Dealing with only one other input
if len(non_eye_inputs) != 1: if len(non_eye_inputs) != 1:
return None return None
useful_eye, useful_non_eye = eye_input[0], non_eye_inputs[0] eye_input, non_eye_input = eye_input[0], non_eye_inputs[0]
# Checking if original x was scalar/vector/matrix # Checking if original x was scalar/vector/matrix
if useful_non_eye.type.broadcastable[-2:] == (True, True): if non_eye_input.type.broadcastable[-2:] == (True, True):
# For scalar # For scalar
det_val = useful_non_eye.squeeze(axis=(-1, -2)) ** (useful_eye.shape[0]) det_val = non_eye_input.squeeze(axis=(-1, -2)) ** (eye_input.shape[0])
elif useful_non_eye.type.broadcastable[-2:] == (False, False): elif non_eye_input.type.broadcastable[-2:] == (False, False):
# For Matrix # For Matrix
det_val = useful_non_eye.diagonal(axis1=-1, axis2=-2).prod(axis=-1) det_val = non_eye_input.diagonal(axis1=-1, axis2=-2).prod(axis=-1)
else: else:
# For vector # For vector
det_val = useful_non_eye.prod(axis=(-1, -2)) det_val = non_eye_input.prod(axis=(-1, -2))
det_val = det_val.astype(node.outputs[0].type.dtype) det_val = det_val.astype(node.outputs[0].type.dtype)
return [det_val] return [det_val]
arange = ARange("int64")
det_diag_from_diag = PatternNodeRewriter(
(
det,
(
advanced_set_subtensor,
(alloc, 0, "sh1", "sh2"),
"x",
(arange, 0, "stop", 1),
(arange, 0, "stop", 1),
),
),
(prod, "x"),
name="det_diag_from_diag",
allow_multiple_clients=True,
)
register_canonicalize(det_diag_from_diag)
register_stabilize(det_diag_from_diag)
register_specialize(det_diag_from_diag)
@register_canonicalize @register_canonicalize
@register_stabilize @register_stabilize
@register_specialize @register_specialize
......
from pytensor import clone_replace
from pytensor.compile import optdb
from pytensor.compile.builders import OpFromGraph
from pytensor.graph import node_rewriter
from pytensor.graph.rewriting.basic import copy_stack_trace, in2out
from pytensor.tensor.basic import AllocDiag
from pytensor.tensor.rewriting.basic import register_specialize
@node_rewriter([OpFromGraph])
def inline_ofg_expansion(fgraph, node):
"""
This optimization expands internal graph of OpFromGraph.
Only performed if node.op.is_inline == True
Doing so can improve optimization at the cost of compilation speed.
"""
op = node.op
if not op.is_inline:
return False
new_out = clone_replace(op.inner_outputs, dict(zip(op.inner_inputs, node.inputs)))
copy_stack_trace(op.inner_outputs, new_out)
return new_out
# We want to run this before the first merge optimizer
# and before the first scan optimizer.
optdb.register(
"inline_ofg_expansion",
in2out(inline_ofg_expansion),
"fast_compile",
"fast_run",
position=-0.01,
)
@register_specialize("inline_ofg")
@node_rewriter([AllocDiag])
def late_inline_OpFromGraph(fgraph, node):
"""
Inline `OpFromGraph` nodes.
OpFromGraph nodes are used to compactly represent the output of a function graph. Certain `Ops`, like, einsum,
diag, and kron, are implemented using pytensor `Op`s. As a result, their outputs are not a single `Op`, but a
graph. To allow rewrites to easily spot and manipulate these "composite functions", we use the `OpFromGraph` node.
This node is a thin wrapper around the output graph. It is not, however, meant to be included in the final
program, because it hides the inner graph from certain optimizations.
This rewrite specifies that all `OpFromGraph` nodes should be replaced by their inner graphs by setting the
`inplace=True` flag.
Parameters
----------
fgraph: FunctionGraph
The function graph being rewritten
node: Apply
Node of the function graph to be optimized
Returns
-------
"""
op = node.op
new_out = clone_replace(op.inner_outputs, dict(zip(op.inner_inputs, node.inputs)))
copy_stack_trace(op.inner_outputs, new_out)
return new_out
...@@ -4,6 +4,7 @@ from functools import partial ...@@ -4,6 +4,7 @@ from functools import partial
import numpy as np import numpy as np
import pytest import pytest
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function import function from pytensor.compile.function import function
from pytensor.compile.mode import get_mode from pytensor.compile.mode import get_mode
from pytensor.compile.sharedvalue import SharedVariable, shared from pytensor.compile.sharedvalue import SharedVariable, shared
...@@ -13,7 +14,7 @@ from pytensor.graph.fg import FunctionGraph ...@@ -13,7 +14,7 @@ from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op, get_test_value from pytensor.graph.op import Op, get_test_value
from pytensor.ifelse import ifelse from pytensor.ifelse import ifelse
from pytensor.raise_op import assert_op from pytensor.raise_op import assert_op
from pytensor.tensor.type import dscalar, scalar, vector from pytensor.tensor.type import dscalar, matrices, scalar, vector
@pytest.fixture(scope="module", autouse=True) @pytest.fixture(scope="module", autouse=True)
...@@ -209,3 +210,19 @@ def test_jax_checkandraise(): ...@@ -209,3 +210,19 @@ def test_jax_checkandraise():
def set_test_value(x, v): def set_test_value(x, v):
x.tag.test_value = v x.tag.test_value = v
return x return x
def test_OpFromGraph():
x, y, z = matrices("xyz")
ofg_1 = OpFromGraph([x, y], [x + y], inline=False)
ofg_2 = OpFromGraph([x, y], [x * y, x - y], inline=False)
o1, o2 = ofg_2(y, z)
out = ofg_1(x, o1) + o2
out_fg = FunctionGraph([x, y, z], [out])
xv = np.ones((2, 2), dtype=config.floatX)
yv = np.ones((2, 2), dtype=config.floatX) * 3
zv = np.ones((2, 2), dtype=config.floatX) * 5
compare_jax_and_py(out_fg, [xv, yv, zv])
...@@ -362,6 +362,8 @@ class TestBatchedVectorBSolveToMatrixBSolve: ...@@ -362,6 +362,8 @@ class TestBatchedVectorBSolveToMatrixBSolve:
ids=["block_diag", "kron"], ids=["block_diag", "kron"],
) )
def test_local_lift_through_linalg(constructor, f_op, f, g_op, g): def test_local_lift_through_linalg(constructor, f_op, f, g_op, g):
rng = np.random.default_rng(sum(map(ord, "lift_through_linalg")))
if pytensor.config.floatX.endswith("32"): if pytensor.config.floatX.endswith("32"):
pytest.skip("Test is flaky at half precision") pytest.skip("Test is flaky at half precision")
...@@ -371,6 +373,7 @@ def test_local_lift_through_linalg(constructor, f_op, f, g_op, g): ...@@ -371,6 +373,7 @@ def test_local_lift_through_linalg(constructor, f_op, f, g_op, g):
f1 = pytensor.function( f1 = pytensor.function(
[A, B], X, mode=get_default_mode().including("local_lift_through_linalg") [A, B], X, mode=get_default_mode().including("local_lift_through_linalg")
) )
f2 = pytensor.function( f2 = pytensor.function(
[A, B], X, mode=get_default_mode().excluding("local_lift_through_linalg") [A, B], X, mode=get_default_mode().excluding("local_lift_through_linalg")
) )
...@@ -386,9 +389,7 @@ def test_local_lift_through_linalg(constructor, f_op, f, g_op, g): ...@@ -386,9 +389,7 @@ def test_local_lift_through_linalg(constructor, f_op, f, g_op, g):
assert len(f_ops) == 2 assert len(f_ops) == 2
assert len(g_ops) == 1 assert len(g_ops) == 1
test_vals = [ test_vals = [rng.normal(size=(3,) * A.ndim).astype(config.floatX) for _ in range(2)]
np.random.normal(size=(3,) * A.ndim).astype(config.floatX) for _ in range(2)
]
test_vals = [x @ np.swapaxes(x, -1, -2) for x in test_vals] test_vals = [x @ np.swapaxes(x, -1, -2) for x in test_vals]
np.testing.assert_allclose(f1(*test_vals), f2(*test_vals), atol=1e-8) np.testing.assert_allclose(f1(*test_vals), f2(*test_vals), atol=1e-8)
...@@ -403,13 +404,18 @@ def test_det_diag_from_eye_mul(shape): ...@@ -403,13 +404,18 @@ def test_det_diag_from_eye_mul(shape):
# Initializing x based on scalar/vector/matrix # Initializing x based on scalar/vector/matrix
x = pt.tensor("x", shape=shape) x = pt.tensor("x", shape=shape)
y = pt.eye(7) * x y = pt.eye(7) * x
# Calculating determinant value using pt.linalg.det # Calculating determinant value using pt.linalg.det
z_det = pt.linalg.det(y) z_det = pt.linalg.det(y)
# REWRITE TEST # REWRITE TEST
f_rewritten = function([x], z_det, mode="FAST_RUN") f_rewritten = function([x], z_det, mode="FAST_RUN")
nodes = f_rewritten.maker.fgraph.apply_nodes nodes = f_rewritten.maker.fgraph.apply_nodes
assert not any(isinstance(node.op, Det) for node in nodes)
assert not any(
isinstance(node.op, Det) or isinstance(getattr(node.op, "core_op", None), Det)
for node in nodes
)
# NUMERIC VALUE TEST # NUMERIC VALUE TEST
if len(shape) == 0: if len(shape) == 0:
...@@ -418,6 +424,7 @@ def test_det_diag_from_eye_mul(shape): ...@@ -418,6 +424,7 @@ def test_det_diag_from_eye_mul(shape):
x_test = np.random.rand(*shape).astype(config.floatX) x_test = np.random.rand(*shape).astype(config.floatX)
else: else:
x_test = np.random.rand(*shape).astype(config.floatX) x_test = np.random.rand(*shape).astype(config.floatX)
x_test_matrix = np.eye(7) * x_test x_test_matrix = np.eye(7) * x_test
det_val = np.linalg.det(x_test_matrix) det_val = np.linalg.det(x_test_matrix)
rewritten_val = f_rewritten(x_test) rewritten_val = f_rewritten(x_test)
...@@ -459,6 +466,7 @@ def test_dont_apply_det_diag_rewrite_for_1_1(): ...@@ -459,6 +466,7 @@ def test_dont_apply_det_diag_rewrite_for_1_1():
x_diag = pt.eye(1, 1) * x x_diag = pt.eye(1, 1) * x
y = pt.linalg.det(x_diag) y = pt.linalg.det(x_diag)
f_rewritten = function([x], y, mode="FAST_RUN") f_rewritten = function([x], y, mode="FAST_RUN")
nodes = f_rewritten.maker.fgraph.apply_nodes nodes = f_rewritten.maker.fgraph.apply_nodes
assert any(isinstance(node.op, Det) for node in nodes) assert any(isinstance(node.op, Det) for node in nodes)
...@@ -468,6 +476,7 @@ def test_dont_apply_det_diag_rewrite_for_1_1(): ...@@ -468,6 +476,7 @@ def test_dont_apply_det_diag_rewrite_for_1_1():
x_test_matrix = np.eye(1, 1) * x_test x_test_matrix = np.eye(1, 1) * x_test
det_val = np.linalg.det(x_test_matrix) det_val = np.linalg.det(x_test_matrix)
rewritten_val = f_rewritten(x_test) rewritten_val = f_rewritten(x_test)
assert_allclose( assert_allclose(
det_val, det_val,
rewritten_val, rewritten_val,
......
import pytest
import pytensor
import pytensor.tensor as pt
from pytensor import config
from pytensor.compile.builders import OpFromGraph
@pytest.mark.skipif(
config.mode == "FAST_COMPILE",
reason="Rewrite is not applied in FAST_COMPILE mode",
)
def test_alloc_diag_inlined():
x = pt.tensor("x", shape=(None,))
z = pt.diag(x)
assert isinstance(z.owner.op, OpFromGraph)
f = pytensor.function([x], z)
nodes = f.maker.fgraph.apply_nodes
assert not any(isinstance(node.op, OpFromGraph) for node in nodes)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论