Unverified commit cac9febf, authored by Jesse Grabowski, committed by GitHub

Add `OpFromGraph` wrapper around `alloc_diag` (#915)

* Add `OpFromGraph` wrapper around `alloc_diag`
* Remove deprecated `AllocDiag` `Op`, rename `AllocDiag2 -> AllocDiag`
* Set `inline = False`
* Add rewrite to inline all `OpFromGraph` `Op`s
* Add `is_offset_zero` helper to `Eye`
* Add `is_left_expand_dims` and `is_right_expand_dims` attributes to `DimShuffle`
* Seed `test_local_lift_through_linalg` test
Parent 6ad1c5cf
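In short: `pt.diag` / `pt.alloc_diag` still build the same graph, but now wrap it in an `AllocDiag` node (a subclass of `OpFromGraph`) that rewrites can recognize, and a late rewrite inlines the wrapper again before compilation. A minimal sketch of the intended end-to-end behavior, mirroring the new `test_alloc_diag_inlined` test at the bottom of this diff:

```python
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph

x = pt.vector("x")
z = pt.diag(x)  # now an AllocDiag node: a thin OpFromGraph wrapper

# The wrapper is visible to rewrites while the graph is being optimized...
assert isinstance(z.owner.op, OpFromGraph)

# ...but is inlined away in the compiled function.
f = pytensor.function([x], z)
assert not any(isinstance(n.op, OpFromGraph) for n in f.maker.fgraph.apply_nodes)
print(f(np.arange(3, dtype=pytensor.config.floatX)))  # 0, 1, 2 on the diagonal
```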
pytensor/compile/builders.py:

@@ -8,7 +8,6 @@ from typing import Union, cast
from pytensor.compile.function import function
from pytensor.compile.function.pfunc import rebuild_collect_shared
-from pytensor.compile.mode import optdb
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.configdefaults import config
from pytensor.gradient import DisconnectedType, Rop, grad
@@ -24,7 +23,6 @@ from pytensor.graph.fg import FunctionGraph
from pytensor.graph.null_type import NullType
from pytensor.graph.op import HasInnerGraph, Op
from pytensor.graph.replace import clone_replace
-from pytensor.graph.rewriting.basic import in2out, node_rewriter
from pytensor.graph.utils import MissingInputError
@@ -575,7 +573,7 @@ class OpFromGraph(Op, HasInnerGraph):
            for inp_grad in input_grads
            if not isinstance(inp_grad.type, DisconnectedType | NullType)
        ]
-        lop_op = type(self)(
+        lop_op = OpFromGraph(
            inputs=inner_inputs + connected_inner_outputs + connected_output_grads,
            outputs=connected_input_grads,
            inline=self.is_inline,
@@ -669,7 +667,7 @@ class OpFromGraph(Op, HasInnerGraph):
            for out_grad in output_grads
            if not isinstance(out_grad.type, DisconnectedType | NullType)
        ]
-        rop_op = type(self)(
+        rop_op = OpFromGraph(
            inputs=inner_inputs + eval_points,
            outputs=filtered_output_grads,
            inline=self.is_inline,
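Presumably `type(self)(...)` had to become `OpFromGraph(...)` in `L_op`/`R_op` because subclasses such as the new `AllocDiag` add required constructor arguments, so re-instantiating the subclass with only the base-class keywords would raise. A sketch of that failure mode with a hypothetical subclass:

```python
import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph

class MyWrapper(OpFromGraph):
    """Hypothetical subclass with an extra required kwarg, like AllocDiag."""

    def __init__(self, *args, axis, **kwargs):
        self.axis = axis
        super().__init__(*args, **kwargs)

x = pt.vector("x")
op = MyWrapper(inputs=[x], outputs=[x.cumsum()], axis=0)

# type(op)(inputs=..., outputs=...) would raise TypeError (missing `axis`),
# while building the gradient wrapper as a plain OpFromGraph always works:
grad_wrapper = OpFromGraph(inputs=[x], outputs=[x.cumsum()])
```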
@@ -852,29 +850,3 @@ class OpFromGraph(Op, HasInnerGraph):
        assert len(variables) == len(outputs)
        for output, variable in zip(outputs, variables):
            output[0] = variable
-
-
-@node_rewriter([OpFromGraph])
-def inline_ofg_expansion(fgraph, node):
-    """
-    This optimization expands internal graph of OpFromGraph.
-    Only performed if node.op.is_inline == True
-    Doing so can improve optimization at the cost of compilation speed.
-    """
-    op = node.op
-    if not isinstance(op, OpFromGraph):
-        return False
-    if not op.is_inline:
-        return False
-    return clone_replace(op.inner_outputs, dict(zip(op.inner_inputs, node.inputs)))
-
-
-# We want to run this before the first merge optimizer
-# and before the first scan optimizer.
-optdb.register(
-    "inline_ofg_expansion",
-    in2out(inline_ofg_expansion),
-    "fast_compile",
-    "fast_run",
-    position=-0.01,
-)
pytensor/link/jax/dispatch/basic.py:

import warnings
from collections.abc import Callable
from functools import singledispatch
import jax
import jax.numpy as jnp
import numpy as np
from pytensor.compile import JAX
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.ops import DeepCopyOp, ViewOp
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
@@ -114,3 +117,24 @@ def jax_funcify_ViewOp(op, **kwargs):
        return x

    return viewop


+@jax_funcify.register(OpFromGraph)
+def jax_funcify_OpFromGraph(ofg: OpFromGraph, node=None, **kwargs) -> Callable:
+    _ = kwargs.pop("storage_map", None)
+
+    # Apply inner rewrites
+    JAX.optimizer(ofg.fgraph)
+    fgraph_fn = jax_funcify(ofg.fgraph, **kwargs)
+
+    if len(ofg.fgraph.outputs) == 1:
+
+        def opfromgraph(*inputs):
+            return fgraph_fn(*inputs)[0]
+
+    else:
+
+        def opfromgraph(*inputs):
+            return fgraph_fn(*inputs)
+
+    return opfromgraph
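A sketch of the new dispatch in action, shaped like the `test_OpFromGraph` JAX test below (assumes `jax` is installed):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph

x, y = pt.matrices("xy")
ofg = OpFromGraph([x, y], [x + y], inline=False)

# Compiling in JAX mode routes through jax_funcify_OpFromGraph, which first
# runs the JAX rewrites on the wrapper's inner fgraph.
f = pytensor.function([x, y], ofg(x, y) * 2, mode="JAX")

xv = np.ones((2, 2), dtype=pytensor.config.floatX)
print(f(xv, xv))  # [[4. 4.] [4. 4.]]
```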
pytensor/tensor/basic.py:

@@ -21,6 +21,7 @@ import pytensor
import pytensor.scalar.sharedvar
from pytensor import compile, config, printing
from pytensor import scalar as ps
+from pytensor.compile.builders import OpFromGraph
from pytensor.gradient import DisconnectedType, grad_undefined
from pytensor.graph import RewriteDatabaseQuery
from pytensor.graph.basic import Apply, Constant, Variable, equal_computations
@@ -1334,6 +1335,25 @@ class Eye(Op):
    def grad(self, inp, grads):
        return [grad_undefined(self, i, inp[i]) for i in range(3)]

+    @staticmethod
+    def is_offset_zero(node) -> bool:
+        """
+        Test if an Eye Op has a diagonal offset of zero
+
+        Parameters
+        ----------
+        node
+            Eye node to test
+
+        Returns
+        -------
+        is_offset_zero: bool
+            True if the offset is zero (``k = 0``).
+        """
+        offset = node.inputs[-1]
+        return isinstance(offset, Constant) and offset.data.item() == 0
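A quick usage sketch of the new helper (the diagonal offset `k` is the last input of an `Eye` node):

```python
import pytensor.tensor as pt
from pytensor.tensor.basic import Eye

main_diag = pt.eye(3)        # k defaults to 0
super_diag = pt.eye(3, k=1)  # ones on the first superdiagonal

assert Eye.is_offset_zero(main_diag.owner)
assert not Eye.is_offset_zero(super_diag.owner)
```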
def eye(n, m=None, k=0, dtype=None):
    """Return a 2-D array with ones on the diagonal and zeros elsewhere.
@@ -3749,109 +3769,37 @@ def trace(a, offset=0, axis1=0, axis2=1):
    return diagonal(a, offset=offset, axis1=axis1, axis2=axis2).sum(-1)


-class AllocDiag(Op):
-    """An `Op` that copies a vector to the diagonal of a zero-ed matrix."""
-
-    __props__ = ("offset", "axis1", "axis2")
-
-    def __init__(self, offset=0, axis1=0, axis2=1):
-        """
-        Parameters
-        ----------
-        offset: int
-            Offset of the diagonal from the main diagonal defined by `axis1`
-            and `axis2`. Can be positive or negative. Defaults to main
-            diagonal (i.e. 0).
-        axis1: int
-            Axis to be used as the first axis of the 2-D sub-arrays to which
-            the diagonals will be allocated. Defaults to first axis (i.e. 0).
-        axis2: int
-            Axis to be used as the second axis of the 2-D sub-arrays to which
-            the diagonals will be allocated. Defaults to second axis (i.e. 1).
-        """
-        warnings.warn(
-            "AllocDiag is deprecated. Use `alloc_diag` instead",
-            FutureWarning,
-        )
-        self.offset = offset
-        if axis1 < 0 or axis2 < 0:
-            raise NotImplementedError("AllocDiag does not support negative axis")
-        if axis1 == axis2:
-            raise ValueError("axis1 and axis2 cannot be the same")
-        self.axis1 = axis1
-        self.axis2 = axis2
-
-    def make_node(self, diag):
-        diag = as_tensor_variable(diag)
-        if diag.type.ndim < 1:
-            raise ValueError(
-                "AllocDiag needs an input with 1 or more dimensions", diag.type
-            )
-        return Apply(
-            self,
-            [diag],
-            [diag.type.clone(shape=(None,) * (diag.ndim + 1))()],
-        )
-
-    def perform(self, node, inputs, outputs):
-        (x,) = inputs
-        (z,) = outputs
-
-        axis1 = np.minimum(self.axis1, self.axis2)
-        axis2 = np.maximum(self.axis1, self.axis2)
-        offset = self.offset
-
-        # Create array with one extra dimension for resulting matrix
-        result_shape = x.shape[:-1] + (x.shape[-1] + abs(offset),) * 2
-        result = np.zeros(result_shape, dtype=x.dtype)
-
-        # Create slice for diagonal in final 2 axes
-        idxs = np.arange(x.shape[-1])
-        diagonal_slice = (len(result_shape) - 2) * [slice(None)] + [
-            idxs + np.maximum(0, -offset),
-            idxs + np.maximum(0, offset),
-        ]
-
-        # Fill in final 2 axes with x
-        result[tuple(diagonal_slice)] = x
-
-        if len(x.shape) > 1:
-            # Re-order axes so they correspond to diagonals at axis1, axis2
-            axes = list(range(len(x.shape[:-1])))
-            last_idx = axes[-1]
-            axes = axes[:axis1] + [last_idx + 1] + axes[axis1:]
-            axes = axes[:axis2] + [last_idx + 2] + axes[axis2:]
-            result = result.transpose(axes)
-
-        z[0] = result
-
-    def grad(self, inputs, gout):
-        (gz,) = gout
-        return [diagonal(gz, offset=self.offset, axis1=self.axis1, axis2=self.axis2)]
-
-    def infer_shape(self, fgraph, nodes, shapes):
-        (x_shape,) = shapes
-        axis1 = np.minimum(self.axis1, self.axis2)
-        axis2 = np.maximum(self.axis1, self.axis2)
-
-        result_shape = list(x_shape[:-1])
-        diag_shape = x_shape[-1] + abs(self.offset)
-        result_shape = result_shape[:axis1] + [diag_shape] + result_shape[axis1:]
-        result_shape = result_shape[:axis2] + [diag_shape] + result_shape[axis2:]
-        return [tuple(result_shape)]
-
-    def __setstate__(self, state):
-        if "view_map" in state:
-            del state["view_map"]
-
-        self.__dict__.update(state)
-
-        if "offset" not in state:
-            self.offset = 0
-        if "axis1" not in state:
-            self.axis1 = 0
-        if "axis2" not in state:
-            self.axis2 = 1
+class AllocDiag(OpFromGraph):
+    """
+    Wrapper Op for alloc_diag graphs
+    """
+
+    __props__ = ("axis1", "axis2")
+
+    def __init__(self, *args, axis1, axis2, offset, **kwargs):
+        self.axis1 = axis1
+        self.axis2 = axis2
+        self.offset = offset
+
+        super().__init__(*args, **kwargs, strict=True)
+
+    @staticmethod
+    def is_offset_zero(node) -> bool:
+        """
+        Test if an AllocDiag Op has a diagonal offset of zero
+
+        Parameters
+        ----------
+        node
+            AllocDiag node to test
+
+        Returns
+        -------
+        is_offset_zero: bool
+            True if the offset is zero (``k = 0``).
+        """
+        return node.op.offset == 0

def alloc_diag(diag, offset=0, axis1=0, axis2=1):
@@ -3862,6 +3810,7 @@ def alloc_diag(diag, offset=0, axis1=0, axis2=1):
    from pytensor.tensor import set_subtensor

    diag = as_tensor_variable(diag)
+    axis1, axis2 = normalize_axis_tuple((axis1, axis2), ndim=diag.type.ndim + 1)

    if axis1 > axis2:
        axis1, axis2 = axis2, axis1
@@ -3888,7 +3837,9 @@ def alloc_diag(diag, offset=0, axis1=0, axis2=1):
        axes = axes[:axis2] + [last_idx + 2] + axes[axis2:]
        result = result.transpose(axes)

-    return result
+    return AllocDiag(
+        inputs=[diag], outputs=[result], axis1=axis1, axis2=axis2, offset=offset
+    )(diag)
def diag(v, k=0):
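The practical effect of the new return value: the whole `set_subtensor` construction is hidden behind a single node that rewrites can match by type and offset alone. A small sketch:

```python
import pytensor.tensor as pt
from pytensor.tensor.basic import AllocDiag, alloc_diag

x = pt.vector("x")
z = alloc_diag(x, offset=0)

# One recognizable wrapper node instead of a raw set_subtensor subgraph
assert isinstance(z.owner.op, AllocDiag)
assert AllocDiag.is_offset_zero(z.owner)
```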
pytensor/tensor/elemwise.py:

@@ -185,6 +185,14 @@ class DimShuffle(ExternalCOp):
        self.augment = sorted(i for i, x in enumerate(new_order) if x == "x")
        self.drop = drop

+        input_ndim = len(input_broadcastable)
+        self.is_left_expand_dims = self.augment and (
+            input_ndim == 0 or new_order[-input_ndim:] == list(range(input_ndim))
+        )
+        self.is_right_expand_dims = self.augment and new_order[:input_ndim] == list(
+            range(input_ndim)
+        )
+
        if self.inplace:
            self.view_map = {0: [0]}
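The two new flags classify a `DimShuffle` that only prepends or only appends broadcastable dimensions, which is exactly what batched `tensor * eye` graphs produce. A sketch of what they evaluate to:

```python
import pytensor.tensor as pt

x = pt.matrix("x")  # ndim == 2

left = x.dimshuffle("x", 0, 1)   # new_order ends with (0, 1)
right = x.dimshuffle(0, 1, "x")  # new_order starts with (0, 1)

assert left.owner.op.is_left_expand_dims
assert right.owner.op.is_right_expand_dims
```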
pytensor/tensor/rewriting/__init__.py:

@@ -10,6 +10,7 @@ import pytensor.tensor.rewriting.extra_ops
import pytensor.tensor.rewriting.jax
import pytensor.tensor.rewriting.linalg
import pytensor.tensor.rewriting.math
+import pytensor.tensor.rewriting.ofg
import pytensor.tensor.rewriting.shape
import pytensor.tensor.rewriting.special
import pytensor.tensor.rewriting.subtensor
pytensor/tensor/rewriting/linalg.py:

@@ -5,12 +5,16 @@ from typing import cast
from pytensor import Variable
from pytensor.graph import Apply, FunctionGraph
from pytensor.graph.rewriting.basic import (
    PatternNodeRewriter,
    copy_stack_trace,
    node_rewriter,
)
from pytensor.scalar.basic import Mul
-from pytensor.tensor.basic import ARange, Eye, TensorVariable, alloc, diagonal
+from pytensor.tensor.basic import (
+    AllocDiag,
+    Eye,
+    TensorVariable,
+    diagonal,
+)
from pytensor.tensor.blas import Dot22
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle, Elemwise
@@ -41,7 +45,6 @@ from pytensor.tensor.slinalg import (
    solve,
    solve_triangular,
)
-from pytensor.tensor.subtensor import advanced_set_subtensor
logger = logging.getLogger(__name__)
@@ -402,30 +405,68 @@ def _find_diag_from_eye_mul(potential_mul_input):
    eye_input = [
        mul_input
        for mul_input in inputs_to_mul
-        if mul_input.owner and isinstance(mul_input.owner.op, Eye)
+        if mul_input.owner
+        and (
+            isinstance(mul_input.owner.op, Eye)
+            or
+            # This whole condition checks if there is an Eye hiding inside a DimShuffle.
+            # This arises from batched elementwise multiplication between a tensor and an eye, e.g.:
+            # tensor(shape=(None, 3, 3)) * eye(3). This is still potentially valid for diag rewrites.
+            (
+                isinstance(mul_input.owner.op, DimShuffle)
+                and (
+                    mul_input.owner.op.is_left_expand_dims
+                    or mul_input.owner.op.is_right_expand_dims
+                )
+                and mul_input.owner.inputs[0].owner is not None
+                and isinstance(mul_input.owner.inputs[0].owner.op, Eye)
+            )
+        )
    ]
-    # Check if 1's are being put on the main diagonal only (k = 0)
-    if eye_input and getattr(eye_input[0].owner.inputs[-1], "data", -1).item() != 0:
-        return None
-
-    # If the broadcast pattern of eye_input is not (False, False), we do not get a diagonal matrix and thus, dont need to apply the rewrite
-    if eye_input and eye_input[0].broadcastable[-2:] != (False, False):
-        return None
+    if not eye_input:
+        return None
+
+    eye_input = eye_input[0]
+
+    # If eye_input is an Eye Op (it's not wrapped in a DimShuffle), check it doesn't have an offset
+    if isinstance(eye_input.owner.op, Eye) and (
+        not Eye.is_offset_zero(eye_input.owner)
+        or eye_input.broadcastable[-2:] != (False, False)
+    ):
+        return None
+
+    # Otherwise, an Eye was found but it is wrapped in a DimShuffle (i.e. there was some broadcasting going on).
+    # We have to look inside DimShuffle to decide if the rewrite can be applied
+    if isinstance(eye_input.owner.op, DimShuffle) and (
+        eye_input.owner.op.is_left_expand_dims
+        or eye_input.owner.op.is_right_expand_dims
+    ):
+        inner_eye = eye_input.owner.inputs[0]
+        # We can only rewrite when the Eye is on the main diagonal (the offset is zero) and the identity isn't
+        # degenerate
+        if not Eye.is_offset_zero(inner_eye.owner) or inner_eye.broadcastable[-2:] != (
+            False,
+            False,
+        ):
+            return None

    # Get all non Eye inputs (scalars/matrices/vectors)
-    non_eye_inputs = list(set(inputs_to_mul) - set(eye_input))
+    non_eye_inputs = list(set(inputs_to_mul) - {eye_input})
    return eye_input, non_eye_inputs
@register_canonicalize("shape_unsafe")
@register_stabilize("shape_unsafe")
@node_rewriter([det])
-def rewrite_det_diag_from_eye_mul(fgraph, node):
+def rewrite_det_diag_to_prod_diag(fgraph, node):
    """
-    This rewrite takes advantage of the fact that for a diagonal matrix, the determinant value is the product of its diagonal elements.
+    This rewrite takes advantage of the fact that for a diagonal matrix, the determinant value is the product of its
+    diagonal elements.

-    The presence of a diagonal matrix is detected by inspecting the graph. This rewrite can identify diagonal matrices that arise as the result of elementwise multiplication with an identity matrix. Specialized computation is used to make this rewrite as efficient as possible, depending on whether the multiplication was with a scalar, vector or a matrix.
+    The presence of a diagonal matrix is detected by inspecting the graph. This rewrite can identify diagonal matrices
+    that arise as the result of elementwise multiplication with an identity matrix. Specialized computation is used to
+    make this rewrite as efficient as possible, depending on whether the multiplication was with a scalar,
+    vector or a matrix.

    Parameters
    ----------
@@ -439,53 +480,45 @@ def rewrite_det_diag_from_eye_mul(fgraph, node):
    list of Variable, optional
        List of optimized variables, or None if no optimization was performed
    """
-    potential_mul_input = node.inputs[0]
-    eye_non_eye_inputs = _find_diag_from_eye_mul(potential_mul_input)
-    if eye_non_eye_inputs is None:
-        return None
-    eye_input, non_eye_inputs = eye_non_eye_inputs
+    inputs = node.inputs[0]
+
+    # Check for use of pt.diag first
+    if (
+        inputs.owner
+        and isinstance(inputs.owner.op, AllocDiag)
+        and AllocDiag.is_offset_zero(inputs.owner)
+    ):
+        diag_input = inputs.owner.inputs[0]
+        det_val = diag_input.prod(axis=-1)
+        return [det_val]
+
+    # Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix
+    inputs_or_none = _find_diag_from_eye_mul(inputs)
+    if inputs_or_none is None:
+        return None
+
+    eye_input, non_eye_inputs = inputs_or_none
    # Dealing with only one other input
    if len(non_eye_inputs) != 1:
        return None

-    useful_eye, useful_non_eye = eye_input[0], non_eye_inputs[0]
+    eye_input, non_eye_input = eye_input[0], non_eye_inputs[0]
    # Checking if original x was scalar/vector/matrix
-    if useful_non_eye.type.broadcastable[-2:] == (True, True):
+    if non_eye_input.type.broadcastable[-2:] == (True, True):
        # For scalar
-        det_val = useful_non_eye.squeeze(axis=(-1, -2)) ** (useful_eye.shape[0])
-    elif useful_non_eye.type.broadcastable[-2:] == (False, False):
+        det_val = non_eye_input.squeeze(axis=(-1, -2)) ** (eye_input.shape[0])
+    elif non_eye_input.type.broadcastable[-2:] == (False, False):
        # For Matrix
-        det_val = useful_non_eye.diagonal(axis1=-1, axis2=-2).prod(axis=-1)
+        det_val = non_eye_input.diagonal(axis1=-1, axis2=-2).prod(axis=-1)
    else:
        # For vector
-        det_val = useful_non_eye.prod(axis=(-1, -2))
+        det_val = non_eye_input.prod(axis=(-1, -2))

    det_val = det_val.astype(node.outputs[0].type.dtype)
    return [det_val]
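A sketch of the user-visible effect, mirroring the tests further down (the `pt.diag` path exercises the new `AllocDiag` branch above):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.tensor("x", shape=(7,))
det = pt.linalg.det(pt.diag(x))  # det of a known-diagonal matrix

# After the rewrite this is just a product over x, with no Det node left.
f = pytensor.function([x], det, mode="FAST_RUN")

x_test = np.random.default_rng(0).random(7).astype(pytensor.config.floatX)
np.testing.assert_allclose(f(x_test), np.linalg.det(np.diag(x_test)), rtol=1e-4)
```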
-arange = ARange("int64")
-det_diag_from_diag = PatternNodeRewriter(
-    (
-        det,
-        (
-            advanced_set_subtensor,
-            (alloc, 0, "sh1", "sh2"),
-            "x",
-            (arange, 0, "stop", 1),
-            (arange, 0, "stop", 1),
-        ),
-    ),
-    (prod, "x"),
-    name="det_diag_from_diag",
-    allow_multiple_clients=True,
-)
-
-register_canonicalize(det_diag_from_diag)
-register_stabilize(det_diag_from_diag)
-register_specialize(det_diag_from_diag)
@register_canonicalize
@register_stabilize
@register_specialize
pytensor/tensor/rewriting/ofg.py (new file):
from pytensor import clone_replace
from pytensor.compile import optdb
from pytensor.compile.builders import OpFromGraph
from pytensor.graph import node_rewriter
from pytensor.graph.rewriting.basic import copy_stack_trace, in2out
from pytensor.tensor.basic import AllocDiag
from pytensor.tensor.rewriting.basic import register_specialize
@node_rewriter([OpFromGraph])
def inline_ofg_expansion(fgraph, node):
    """
    Expand the internal graph of an `OpFromGraph` node.

    Only performed if ``node.op.is_inline == True``. Doing so can improve
    optimization at the cost of compilation speed.
    """
    op = node.op
    if not op.is_inline:
        return False

    new_out = clone_replace(op.inner_outputs, dict(zip(op.inner_inputs, node.inputs)))
    copy_stack_trace(op.inner_outputs, new_out)

    return new_out
# We want to run this before the first merge optimizer
# and before the first scan optimizer.
optdb.register(
    "inline_ofg_expansion",
    in2out(inline_ofg_expansion),
    "fast_compile",
    "fast_run",
    position=-0.01,
)
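For reference, this early expansion only fires for wrappers built with `inline=True`; a sketch with a hypothetical one-off wrapper:

```python
import pytensor
import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph

x = pt.scalar("x")
ofg = OpFromGraph([x], [x + 1], inline=True)  # is_inline drives this rewrite

f = pytensor.function([x], ofg(x) * 2)
assert not any(isinstance(n.op, OpFromGraph) for n in f.maker.fgraph.apply_nodes)
```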
@register_specialize("inline_ofg")
@node_rewriter([AllocDiag])
def late_inline_OpFromGraph(fgraph, node):
    """
    Inline `OpFromGraph` nodes.

    `OpFromGraph` nodes are used to compactly represent the output of a function graph. Certain `Op`s, like `einsum`,
    `diag`, and `kron`, are implemented as graphs of other pytensor `Op`s. As a result, their outputs are not a single
    `Op`, but a graph. To allow rewrites to easily spot and manipulate these "composite functions", we use the
    `OpFromGraph` node. This node is a thin wrapper around the output graph. It is not, however, meant to be included
    in the final program, because it hides the inner graph from certain optimizations.

    This rewrite replaces such `OpFromGraph` nodes with the graphs they wrap during the specialization phase (via the
    ``inline_ofg`` tag).

    Parameters
    ----------
    fgraph: FunctionGraph
        The function graph being rewritten
    node: Apply
        Node of the function graph to be optimized

    Returns
    -------
    list of Variable
        The inner-graph outputs, with the node's inputs substituted in
    """
    op = node.op
    new_out = clone_replace(op.inner_outputs, dict(zip(op.inner_inputs, node.inputs)))
    copy_stack_trace(op.inner_outputs, new_out)

    return new_out
tests/link/jax/test_basic.py:

@@ -4,6 +4,7 @@ from functools import partial
import numpy as np
import pytest

+from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function import function
from pytensor.compile.mode import get_mode
from pytensor.compile.sharedvalue import SharedVariable, shared
@@ -13,7 +14,7 @@ from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op, get_test_value
from pytensor.ifelse import ifelse
from pytensor.raise_op import assert_op
-from pytensor.tensor.type import dscalar, scalar, vector
+from pytensor.tensor.type import dscalar, matrices, scalar, vector
@pytest.fixture(scope="module", autouse=True)
@@ -209,3 +210,19 @@ def test_jax_checkandraise():
def set_test_value(x, v):
    x.tag.test_value = v
    return x


+def test_OpFromGraph():
+    x, y, z = matrices("xyz")
+    ofg_1 = OpFromGraph([x, y], [x + y], inline=False)
+    ofg_2 = OpFromGraph([x, y], [x * y, x - y], inline=False)
+
+    o1, o2 = ofg_2(y, z)
+    out = ofg_1(x, o1) + o2
+
+    out_fg = FunctionGraph([x, y, z], [out])
+
+    xv = np.ones((2, 2), dtype=config.floatX)
+    yv = np.ones((2, 2), dtype=config.floatX) * 3
+    zv = np.ones((2, 2), dtype=config.floatX) * 5
+
+    compare_jax_and_py(out_fg, [xv, yv, zv])
tests/tensor/rewriting/test_linalg.py:

@@ -362,6 +362,8 @@ class TestBatchedVectorBSolveToMatrixBSolve:
    ids=["block_diag", "kron"],
)
def test_local_lift_through_linalg(constructor, f_op, f, g_op, g):
+    rng = np.random.default_rng(sum(map(ord, "lift_through_linalg")))
+
    if pytensor.config.floatX.endswith("32"):
        pytest.skip("Test is flaky at half precision")
@@ -371,6 +373,7 @@ def test_local_lift_through_linalg(constructor, f_op, f, g_op, g):
    f1 = pytensor.function(
        [A, B], X, mode=get_default_mode().including("local_lift_through_linalg")
    )
+
    f2 = pytensor.function(
        [A, B], X, mode=get_default_mode().excluding("local_lift_through_linalg")
    )
@@ -386,9 +389,7 @@ def test_local_lift_through_linalg(constructor, f_op, f, g_op, g):
    assert len(f_ops) == 2
    assert len(g_ops) == 1

-    test_vals = [
-        np.random.normal(size=(3,) * A.ndim).astype(config.floatX) for _ in range(2)
-    ]
+    test_vals = [rng.normal(size=(3,) * A.ndim).astype(config.floatX) for _ in range(2)]
    test_vals = [x @ np.swapaxes(x, -1, -2) for x in test_vals]

    np.testing.assert_allclose(f1(*test_vals), f2(*test_vals), atol=1e-8)
@@ -403,13 +404,18 @@ def test_det_diag_from_eye_mul(shape):
    # Initializing x based on scalar/vector/matrix
    x = pt.tensor("x", shape=shape)
    y = pt.eye(7) * x

    # Calculating determinant value using pt.linalg.det
    z_det = pt.linalg.det(y)

    # REWRITE TEST
    f_rewritten = function([x], z_det, mode="FAST_RUN")
    nodes = f_rewritten.maker.fgraph.apply_nodes

-    assert not any(isinstance(node.op, Det) for node in nodes)
+    assert not any(
+        isinstance(node.op, Det) or isinstance(getattr(node.op, "core_op", None), Det)
+        for node in nodes
+    )

    # NUMERIC VALUE TEST
    if len(shape) == 0:
@@ -418,6 +424,7 @@ def test_det_diag_from_eye_mul(shape):
        x_test = np.random.rand(*shape).astype(config.floatX)
    else:
        x_test = np.random.rand(*shape).astype(config.floatX)
+
    x_test_matrix = np.eye(7) * x_test
    det_val = np.linalg.det(x_test_matrix)
    rewritten_val = f_rewritten(x_test)
@@ -459,6 +466,7 @@ def test_dont_apply_det_diag_rewrite_for_1_1():
    x_diag = pt.eye(1, 1) * x
    y = pt.linalg.det(x_diag)
    f_rewritten = function([x], y, mode="FAST_RUN")
+
    nodes = f_rewritten.maker.fgraph.apply_nodes
    assert any(isinstance(node.op, Det) for node in nodes)
@@ -468,6 +476,7 @@ def test_dont_apply_det_diag_rewrite_for_1_1():
    x_test_matrix = np.eye(1, 1) * x_test
    det_val = np.linalg.det(x_test_matrix)
    rewritten_val = f_rewritten(x_test)
+
    assert_allclose(
        det_val,
        rewritten_val,
tests/tensor/rewriting/test_ofg.py (new file):
import pytest
import pytensor
import pytensor.tensor as pt
from pytensor import config
from pytensor.compile.builders import OpFromGraph
@pytest.mark.skipif(
    config.mode == "FAST_COMPILE",
    reason="Rewrite is not applied in FAST_COMPILE mode",
)
def test_alloc_diag_inlined():
    x = pt.tensor("x", shape=(None,))

    z = pt.diag(x)
    assert isinstance(z.owner.op, OpFromGraph)

    f = pytensor.function([x], z)
    nodes = f.maker.fgraph.apply_nodes

    assert not any(isinstance(node.op, OpFromGraph) for node in nodes)