提交 38731adb authored 作者: Maxim Kochurov's avatar Maxim Kochurov 提交者: Ricardo Vieira

move clone_replace to a separate file

上级 24453279
...@@ -73,7 +73,8 @@ from pytensor.configdefaults import config ...@@ -73,7 +73,8 @@ from pytensor.configdefaults import config
__api_version__ = 1 __api_version__ = 1
# isort: off # isort: off
from pytensor.graph.basic import Variable, clone_replace from pytensor.graph.basic import Variable
from pytensor.graph.replace import clone_replace
# isort: on # isort: on
......
...@@ -16,13 +16,13 @@ from pytensor.graph.basic import ( ...@@ -16,13 +16,13 @@ from pytensor.graph.basic import (
Constant, Constant,
NominalVariable, NominalVariable,
Variable, Variable,
clone_replace,
graph_inputs, graph_inputs,
io_connection_pattern, io_connection_pattern,
) )
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.null_type import NullType from pytensor.graph.null_type import NullType
from pytensor.graph.op import HasInnerGraph, Op from pytensor.graph.op import HasInnerGraph, Op
from pytensor.graph.replace import clone_replace
from pytensor.graph.rewriting.basic import in2out, node_rewriter from pytensor.graph.rewriting.basic import in2out, node_rewriter
from pytensor.graph.utils import MissingInputError from pytensor.graph.utils import MissingInputError
from pytensor.tensor.rewriting.shape import ShapeFeature from pytensor.tensor.rewriting.shape import ShapeFeature
......
...@@ -7,9 +7,9 @@ from pytensor.graph.basic import ( ...@@ -7,9 +7,9 @@ from pytensor.graph.basic import (
Constant, Constant,
graph_inputs, graph_inputs,
clone, clone,
clone_replace,
ancestors, ancestors,
) )
from pytensor.graph.replace import clone_replace
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.graph.type import Type from pytensor.graph.type import Type
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
......
...@@ -1151,53 +1151,6 @@ def clone_get_equiv( ...@@ -1151,53 +1151,6 @@ def clone_get_equiv(
return memo return memo
def clone_replace(
output: Collection[Variable],
replace: Optional[
Union[Iterable[Tuple[Variable, Variable]], Dict[Variable, Variable]]
] = None,
**rebuild_kwds,
) -> List[Variable]:
"""Clone a graph and replace subgraphs within it.
It returns a copy of the initial subgraph with the corresponding
substitutions.
Parameters
----------
output
PyTensor expression that represents the computational graph.
replace
Dictionary describing which subgraphs should be replaced by what.
rebuild_kwds
Keywords to `rebuild_collect_shared`.
"""
from pytensor.compile.function.pfunc import rebuild_collect_shared
items: Union[List[Tuple[Variable, Variable]], Tuple[Tuple[Variable, Variable], ...]]
if isinstance(replace, dict):
items = list(replace.items())
elif isinstance(replace, (list, tuple)):
items = replace
elif replace is None:
items = []
else:
raise ValueError(
"replace is neither a dictionary, list, "
f"tuple or None ! The value provided is {replace},"
f"of type {type(replace)}"
)
tmp_replace = [(x, x.type()) for x, y in items]
new_replace = [(x, y) for ((_, x), (_, y)) in zip(tmp_replace, items)]
_, _outs, _ = rebuild_collect_shared(output, [], tmp_replace, [], **rebuild_kwds)
# TODO Explain why we call it twice ?!
_, outs, _ = rebuild_collect_shared(_outs, [], new_replace, [], **rebuild_kwds)
return cast(List[Variable], outs)
def general_toposort( def general_toposort(
outputs: Iterable[T], outputs: Iterable[T],
deps: Callable[[T], Union[OrderedSet, List[T]]], deps: Callable[[T], Union[OrderedSet, List[T]]],
......
from typing import (
Collection,
Dict,
Iterable,
List,
Optional,
Sequence,
Tuple,
Union,
cast,
)
from pytensor.graph.basic import Constant, Variable
def clone_replace(
output: Collection[Variable],
replace: Optional[
Union[Iterable[Tuple[Variable, Variable]], Dict[Variable, Variable]]
] = None,
**rebuild_kwds,
) -> List[Variable]:
"""Clone a graph and replace subgraphs within it.
It returns a copy of the initial subgraph with the corresponding
substitutions.
Parameters
----------
output
PyTensor expression that represents the computational graph.
replace
Dictionary describing which subgraphs should be replaced by what.
rebuild_kwds
Keywords to `rebuild_collect_shared`.
"""
from pytensor.compile.function.pfunc import rebuild_collect_shared
items: Union[List[Tuple[Variable, Variable]], Tuple[Tuple[Variable, Variable], ...]]
if isinstance(replace, dict):
items = list(replace.items())
elif isinstance(replace, (list, tuple)):
items = replace
elif replace is None:
items = []
else:
raise ValueError(
"replace is neither a dictionary, list, "
f"tuple or None ! The value provided is {replace},"
f"of type {type(replace)}"
)
tmp_replace = [(x, x.type()) for x, y in items]
new_replace = [(x, y) for ((_, x), (_, y)) in zip(tmp_replace, items)]
_, _outs, _ = rebuild_collect_shared(output, [], tmp_replace, [], **rebuild_kwds)
# TODO Explain why we call it twice ?!
_, outs, _ = rebuild_collect_shared(_outs, [], new_replace, [], **rebuild_kwds)
return cast(List[Variable], outs)
...@@ -20,8 +20,9 @@ import pytensor.tensor as at ...@@ -20,8 +20,9 @@ import pytensor.tensor as at
from pytensor import as_symbolic from pytensor import as_symbolic
from pytensor.compile import optdb from pytensor.compile import optdb
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.basic import Apply, Variable, clone_replace, is_in_ancestors from pytensor.graph.basic import Apply, Variable, is_in_ancestors
from pytensor.graph.op import _NoPythonOp from pytensor.graph.op import _NoPythonOp
from pytensor.graph.replace import clone_replace
from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
from pytensor.graph.type import HasDataType, HasShape from pytensor.graph.type import HasDataType, HasShape
from pytensor.tensor.shape import Reshape, Shape, SpecifyShape, Unbroadcast from pytensor.tensor.shape import Reshape, Shape, SpecifyShape, Unbroadcast
......
...@@ -6,8 +6,9 @@ import pytensor.tensor as at ...@@ -6,8 +6,9 @@ import pytensor.tensor as at
from pytensor.compile.function.pfunc import construct_pfunc_ins_and_outs from pytensor.compile.function.pfunc import construct_pfunc_ins_and_outs
from pytensor.compile.sharedvalue import SharedVariable, collect_new_shareds from pytensor.compile.sharedvalue import SharedVariable, collect_new_shareds
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.basic import Constant, Variable, clone_replace, graph_inputs from pytensor.graph.basic import Constant, Variable, graph_inputs
from pytensor.graph.op import get_test_value from pytensor.graph.op import get_test_value
from pytensor.graph.replace import clone_replace
from pytensor.graph.utils import MissingInputError, TestValueError from pytensor.graph.utils import MissingInputError, TestValueError
from pytensor.scan.op import Scan, ScanInfo from pytensor.scan.op import Scan, ScanInfo
from pytensor.scan.utils import expand_empty, safe_new, until from pytensor.scan.utils import expand_empty, safe_new, until
......
...@@ -65,13 +65,13 @@ from pytensor.gradient import DisconnectedType, NullType, Rop, grad, grad_undefi ...@@ -65,13 +65,13 @@ from pytensor.gradient import DisconnectedType, NullType, Rop, grad, grad_undefi
from pytensor.graph.basic import ( from pytensor.graph.basic import (
Apply, Apply,
Variable, Variable,
clone_replace,
equal_computations, equal_computations,
graph_inputs, graph_inputs,
io_connection_pattern, io_connection_pattern,
) )
from pytensor.graph.features import NoOutputFromInplace from pytensor.graph.features import NoOutputFromInplace
from pytensor.graph.op import HasInnerGraph, Op from pytensor.graph.op import HasInnerGraph, Op
from pytensor.graph.replace import clone_replace
from pytensor.graph.utils import InconsistencyError, MissingInputError from pytensor.graph.utils import InconsistencyError, MissingInputError
from pytensor.link.c.basic import CLinker from pytensor.link.c.basic import CLinker
from pytensor.link.c.exceptions import MissingGXX from pytensor.link.c.exceptions import MissingGXX
......
...@@ -18,7 +18,6 @@ from pytensor.graph.basic import ( ...@@ -18,7 +18,6 @@ from pytensor.graph.basic import (
Apply, Apply,
Constant, Constant,
Variable, Variable,
clone_replace,
equal_computations, equal_computations,
graph_inputs, graph_inputs,
io_toposort, io_toposort,
...@@ -28,6 +27,7 @@ from pytensor.graph.destroyhandler import DestroyHandler ...@@ -28,6 +27,7 @@ from pytensor.graph.destroyhandler import DestroyHandler
from pytensor.graph.features import ReplaceValidate from pytensor.graph.features import ReplaceValidate
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import compute_test_value from pytensor.graph.op import compute_test_value
from pytensor.graph.replace import clone_replace
from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
from pytensor.graph.rewriting.db import EquilibriumDB, SequenceDB from pytensor.graph.rewriting.db import EquilibriumDB, SequenceDB
from pytensor.graph.type import HasShape from pytensor.graph.type import HasShape
......
...@@ -14,14 +14,9 @@ from pytensor import scalar as aes ...@@ -14,14 +14,9 @@ from pytensor import scalar as aes
from pytensor import tensor as at from pytensor import tensor as at
from pytensor.compile.profiling import ProfileStats from pytensor.compile.profiling import ProfileStats
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.basic import ( from pytensor.graph.basic import Constant, Variable, equal_computations, graph_inputs
Constant,
Variable,
clone_replace,
equal_computations,
graph_inputs,
)
from pytensor.graph.op import get_test_value from pytensor.graph.op import get_test_value
from pytensor.graph.replace import clone_replace
from pytensor.graph.type import HasDataType from pytensor.graph.type import HasDataType
from pytensor.graph.utils import TestValueError from pytensor.graph.utils import TestValueError
from pytensor.tensor.basic import AllocEmpty, cast from pytensor.tensor.basic import AllocEmpty, cast
......
...@@ -4,7 +4,7 @@ from itertools import count ...@@ -4,7 +4,7 @@ from itertools import count
import numpy as np import numpy as np
import pytest import pytest
from pytensor import config, function, shared from pytensor import shared
from pytensor import tensor as at from pytensor import tensor as at
from pytensor.graph.basic import ( from pytensor.graph.basic import (
Apply, Apply,
...@@ -15,7 +15,6 @@ from pytensor.graph.basic import ( ...@@ -15,7 +15,6 @@ from pytensor.graph.basic import (
as_string, as_string,
clone, clone,
clone_get_equiv, clone_get_equiv,
clone_replace,
equal_computations, equal_computations,
general_toposort, general_toposort,
get_var_by_name, get_var_by_name,
...@@ -30,18 +29,9 @@ from pytensor.graph.basic import ( ...@@ -30,18 +29,9 @@ from pytensor.graph.basic import (
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.graph.type import Type from pytensor.graph.type import Type
from pytensor.tensor.math import max_and_argmax from pytensor.tensor.math import max_and_argmax
from pytensor.tensor.type import ( from pytensor.tensor.type import TensorType, iscalars, matrix, scalars, vector
TensorType,
dvector,
fvector,
iscalars,
matrix,
scalars,
vector,
)
from pytensor.tensor.type_other import NoneConst from pytensor.tensor.type_other import NoneConst
from pytensor.tensor.var import TensorVariable from pytensor.tensor.var import TensorVariable
from tests import unittest_tools as utt
from tests.graph.utils import MyInnerGraphOp from tests.graph.utils import MyInnerGraphOp
...@@ -557,131 +547,6 @@ def test_get_var_by_name(): ...@@ -557,131 +547,6 @@ def test_get_var_by_name():
assert res == exp_res assert res == exp_res
class TestCloneReplace:
def test_cloning_no_replace_strict_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
x = vector("x")
y = vector("y")
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(f1, replace=None, rebuild_strict=True, copy_inputs_over=True)
f2_inp = graph_inputs([f2])
assert z in f2_inp
assert x in f2_inp
assert y in f2_inp
def test_cloning_no_replace_strict_not_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
x = vector("x")
y = vector("y")
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(
f1, replace=None, rebuild_strict=True, copy_inputs_over=False
)
f2_inp = graph_inputs([f2])
assert z not in f2_inp
assert x not in f2_inp
assert y not in f2_inp
def test_cloning_replace_strict_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
x = vector("x")
y = vector("y")
y2 = vector("y2")
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(
f1, replace={y: y2}, rebuild_strict=True, copy_inputs_over=True
)
f2_inp = graph_inputs([f2])
assert z in f2_inp
assert x in f2_inp
assert y2 in f2_inp
def test_cloning_replace_not_strict_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
x = vector("x")
y = fvector("y")
y2 = dvector("y2")
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(
f1, replace={y: y2}, rebuild_strict=False, copy_inputs_over=True
)
f2_inp = graph_inputs([f2])
assert z in f2_inp
assert x in f2_inp
assert y2 in f2_inp
def test_cloning_replace_strict_not_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
x = vector("x")
y = vector("y")
y2 = vector("y2")
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(
f1, replace=[(y, y2)], rebuild_strict=True, copy_inputs_over=False
)
f2_inp = graph_inputs([f2])
assert z not in f2_inp
assert x not in f2_inp
assert y2 not in f2_inp
def test_cloning_replace_not_strict_not_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
x = vector("x")
y = fvector("y")
y2 = dvector("y2")
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(
f1, replace=[(y, y2)], rebuild_strict=False, copy_inputs_over=False
)
f2_inp = graph_inputs([f2])
assert z not in f2_inp
assert x not in f2_inp
assert y2 not in f2_inp
def test_clone(self):
def test(x, y, mention_y):
if mention_y:
d = 0.1 + 0 * y
else:
d = 0.1
out = clone_replace(y, replace={x: x + d})
return function([], out)()
x = shared(np.asarray(0.0, dtype=config.floatX))
utt.assert_allclose(
test(x, at.sum((x + 1) ** 2), mention_y=False), 1.21000003815
)
utt.assert_allclose(
test(x, at.sum((x + 1) ** 2), mention_y=True), 1.21000003815
)
def test_clone_new_inputs(): def test_clone_new_inputs():
"""Make sure that `Apply.clone_with_new_inputs` properly handles `Type` changes.""" """Make sure that `Apply.clone_with_new_inputs` properly handles `Type` changes."""
......
import numpy as np
import pytest
import pytensor.tensor as pt
from pytensor import config, function, shared
from pytensor.graph.basic import graph_inputs
from pytensor.graph.replace import clone_replace
from pytensor.tensor import dvector, fvector, vector
from tests import unittest_tools as utt
from tests.graph.utils import MyOp, MyVariable
class TestCloneReplace:
def test_cloning_no_replace_strict_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
x = vector("x")
y = vector("y")
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(f1, replace=None, rebuild_strict=True, copy_inputs_over=True)
f2_inp = graph_inputs([f2])
assert z in f2_inp
assert x in f2_inp
assert y in f2_inp
def test_cloning_no_replace_strict_not_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
x = vector("x")
y = vector("y")
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(
f1, replace=None, rebuild_strict=True, copy_inputs_over=False
)
f2_inp = graph_inputs([f2])
assert z not in f2_inp
assert x not in f2_inp
assert y not in f2_inp
def test_cloning_replace_strict_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
x = vector("x")
y = vector("y")
y2 = vector("y2")
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(
f1, replace={y: y2}, rebuild_strict=True, copy_inputs_over=True
)
f2_inp = graph_inputs([f2])
assert z in f2_inp
assert x in f2_inp
assert y2 in f2_inp
def test_cloning_replace_not_strict_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
x = vector("x")
y = fvector("y")
y2 = dvector("y2")
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(
f1, replace={y: y2}, rebuild_strict=False, copy_inputs_over=True
)
f2_inp = graph_inputs([f2])
assert z in f2_inp
assert x in f2_inp
assert y2 in f2_inp
def test_cloning_replace_strict_not_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
x = vector("x")
y = vector("y")
y2 = vector("y2")
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(
f1, replace=[(y, y2)], rebuild_strict=True, copy_inputs_over=False
)
f2_inp = graph_inputs([f2])
assert z not in f2_inp
assert x not in f2_inp
assert y2 not in f2_inp
def test_cloning_replace_not_strict_not_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
x = vector("x")
y = fvector("y")
y2 = dvector("y2")
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(
f1, replace=[(y, y2)], rebuild_strict=False, copy_inputs_over=False
)
f2_inp = graph_inputs([f2])
assert z not in f2_inp
assert x not in f2_inp
assert y2 not in f2_inp
def test_clone(self):
def test(x, y, mention_y):
if mention_y:
d = 0.1 + 0 * y
else:
d = 0.1
out = clone_replace(y, replace={x: x + d})
return function([], out)()
x = shared(np.asarray(0.0, dtype=config.floatX))
utt.assert_allclose(
test(x, pt.sum((x + 1) ** 2), mention_y=False), 1.21000003815
)
utt.assert_allclose(
test(x, pt.sum((x + 1) ** 2), mention_y=True), 1.21000003815
)
import numpy as np import numpy as np
from pytensor.graph.basic import ( from pytensor.graph.basic import Apply, Constant, NominalVariable, Variable
Apply,
Constant,
NominalVariable,
Variable,
clone_replace,
)
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import HasInnerGraph, Op from pytensor.graph.op import HasInnerGraph, Op
from pytensor.graph.replace import clone_replace
from pytensor.graph.type import Type from pytensor.graph.type import Type
......
...@@ -9,8 +9,9 @@ from pytensor.compile.io import In ...@@ -9,8 +9,9 @@ from pytensor.compile.io import In
from pytensor.compile.mode import get_default_mode from pytensor.compile.mode import get_default_mode
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.gradient import grad, jacobian from pytensor.gradient import grad, jacobian
from pytensor.graph.basic import clone_replace, equal_computations from pytensor.graph.basic import equal_computations
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.replace import clone_replace
from pytensor.scan.op import Scan from pytensor.scan.op import Scan
from pytensor.scan.rewriting import ScanInplaceOptimizer, ScanMerge from pytensor.scan.rewriting import ScanInplaceOptimizer, ScanMerge
from pytensor.scan.utils import until from pytensor.scan.utils import until
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论