提交 c09d92b1 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Clone inner-graph before compiling in OpFromGraph

上级 88f02990
...@@ -5,8 +5,8 @@ from functools import partial ...@@ -5,8 +5,8 @@ from functools import partial
from typing import List, Optional, Sequence, cast from typing import List, Optional, Sequence, cast
import aesara.tensor as at import aesara.tensor as at
from aesara import function
from aesara.compile.function.pfunc import rebuild_collect_shared from aesara.compile.function.pfunc import rebuild_collect_shared
from aesara.compile.function.types import orig_function
from aesara.compile.mode import optdb from aesara.compile.mode import optdb
from aesara.compile.sharedvalue import SharedVariable from aesara.compile.sharedvalue import SharedVariable
from aesara.configdefaults import config from aesara.configdefaults import config
...@@ -326,7 +326,7 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -326,7 +326,7 @@ class OpFromGraph(Op, HasInnerGraph):
name name
A name for debugging purposes. A name for debugging purposes.
kwargs kwargs
Check :func:`orig_function` for more arguments, only works when not Check :func:`aesara.function` for more arguments, only works when not
inline. inline.
""" """
...@@ -903,7 +903,7 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -903,7 +903,7 @@ class OpFromGraph(Op, HasInnerGraph):
if getattr(self, "_fn", None) is not None: if getattr(self, "_fn", None) is not None:
return self._fn return self._fn
self._fn = orig_function(self.inner_inputs, self.inner_outputs, **self.kwargs) self._fn = function(self.inner_inputs, self.inner_outputs, **self.kwargs)
self._fn.trust_input = True self._fn.trust_input = True
return self._fn return self._fn
......
...@@ -541,6 +541,25 @@ class TestOpFromGraph(unittest_tools.InferShapeTester): ...@@ -541,6 +541,25 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
assert np.array_equal(res_2, 1.0) assert np.array_equal(res_2, 1.0)
def test_outputs_consistency(self):
"""Make sure that `OpFromGraph.fn` doesn't change the value of `OpFromGraph.inner_outputs`."""
x = scalar("x")
op = OpFromGraph([x], [x**2 / x], mode="FAST_RUN")
# Confirm that the inner-graph is as expected
assert equal_computations(op.inner_outputs, [x**2 / x], op.inner_inputs, [x])
# These outputs of the compiled `op.fgraph` should differ from the
# original, uncompiled `op.fgraph` outputs
fn = op.fn
new_inputs = fn.maker.fgraph.inputs
new_outputs = fn.maker.fgraph.outputs
assert not equal_computations(new_outputs, [x**2 / x], new_inputs, [x])
# The original `op.fgraph` outputs should stay the same, though
assert equal_computations(op.inner_outputs, [x**2 / x], op.inner_inputs, [x])
@config.change_flags(floatX="float64") @config.change_flags(floatX="float64")
def test_debugprint(): def test_debugprint():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论