提交 79a4bc1e authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Clone inner graphs of Ops that will be mutated during compilation

上级 ef0950a9
......@@ -872,7 +872,7 @@ class OpFromGraph(Op, HasInnerGraph):
def clone(self):
res = copy(self)
res.fgraph = res.fgraph.clone()
res.fgraph = res.fgraph.clone(clone_inner_graphs=True)
return res
def perform(self, node, inputs, outputs):
......
......@@ -838,9 +838,13 @@ class FunctionGraph(MetaObject):
def __repr__(self):
return f"FunctionGraph({', '.join(graph_as_string(self.inputs, self.outputs))})"
def clone(self, check_integrity=True) -> "FunctionGraph":
def clone(
self, check_integrity=True, clone_inner_graphs: bool = False
) -> "FunctionGraph":
"""Clone the graph."""
return self.clone_get_equiv(check_integrity)[0]
return self.clone_get_equiv(
check_integrity, clone_inner_graphs=clone_inner_graphs
)[0]
def clone_get_equiv(
self, check_integrity: bool = True, attach_feature: bool = True, **kwargs
......
......@@ -1521,7 +1521,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
def clone(self) -> "Scan":
res = copy(self)
res.fgraph = res.fgraph.clone()
res.fgraph = res.fgraph.clone(clone_inner_graphs=True)
return res
def make_thunk(self, node, storage_map, compute_map, no_recycling, impl=None):
......
......@@ -211,7 +211,7 @@ class ScipyWrapperOp(Op, HasInnerGraph):
def clone(self):
copy_op = copy(self)
copy_op.fgraph = self.fgraph.clone()
copy_op.fgraph = self.fgraph.clone(clone_inner_graphs=True)
return copy_op
def prepare_node(
......
import numpy as np
import pytest
from pytensor import OpFromGraph, config, function, ifelse
from pytensor import Mode, OpFromGraph, config, function, ifelse, scan
from pytensor import tensor as pt
from pytensor.compile import ViewOp
from pytensor.graph import vectorize_graph
from pytensor.raise_op import assert_op
from pytensor.scalar import Add
from pytensor.scan.op import Scan
from pytensor.tensor import dmatrix, dtensor3, matrix
from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.slinalg import Cholesky
from tests.link.numba.test_basic import compare_numba_and_py
......@@ -197,3 +200,34 @@ def test_check_and_raise():
out = assert_op(x.sum(), np.array(True))
compare_numba_and_py([x], out, [x_test_value])
def test_ofg_with_inner_scan_rewrite():
# Regression test where inner scan would be mutated when compiling outer OFG
ys = pt.tensor("ys", shape=(5, 3, 3))
xs = scan(
lambda y: pt.linalg.cholesky(y),
sequences=[ys],
return_updates=False,
mode=Mode(optimizer=None),
)
xs_ofg = OpFromGraph([ys], [xs])(ys)
fn = function([ys], xs_ofg, mode="NUMBA")
# Check that we have a BlockwiseWithCoreShape in the inner Scan
fn_ofg_op = fn.maker.fgraph.outputs[0].owner.op
assert isinstance(fn_ofg_op, OpFromGraph)
fn_scan_op = fn_ofg_op.fgraph.outputs[0].owner.op
assert isinstance(fn_scan_op, Scan)
fn_cholesky_op = fn_scan_op.fgraph.outputs[0].owner.op
assert isinstance(fn_cholesky_op, BlockwiseWithCoreShape)
assert isinstance(fn_cholesky_op.core_op, Cholesky)
# Check original Ops aren't modified
ofg_op = xs_ofg.owner.op
assert isinstance(ofg_op, OpFromGraph)
scan_op = ofg_op.fgraph.outputs[0].owner.op
assert isinstance(scan_op, Scan)
cholesky_op = scan_op.fgraph.outputs[0].owner.op
assert isinstance(cholesky_op, Blockwise)
assert isinstance(cholesky_op.core_op, Cholesky)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论