提交 38731adb authored 作者: Maxim Kochurov's avatar Maxim Kochurov 提交者: Ricardo Vieira

move clone_replace to a separate file

上级 24453279
......@@ -73,7 +73,8 @@ from pytensor.configdefaults import config
__api_version__ = 1
# isort: off
from pytensor.graph.basic import Variable, clone_replace
from pytensor.graph.basic import Variable
from pytensor.graph.replace import clone_replace
# isort: on
......
......@@ -16,13 +16,13 @@ from pytensor.graph.basic import (
Constant,
NominalVariable,
Variable,
clone_replace,
graph_inputs,
io_connection_pattern,
)
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.null_type import NullType
from pytensor.graph.op import HasInnerGraph, Op
from pytensor.graph.replace import clone_replace
from pytensor.graph.rewriting.basic import in2out, node_rewriter
from pytensor.graph.utils import MissingInputError
from pytensor.tensor.rewriting.shape import ShapeFeature
......
......@@ -7,9 +7,9 @@ from pytensor.graph.basic import (
Constant,
graph_inputs,
clone,
clone_replace,
ancestors,
)
from pytensor.graph.replace import clone_replace
from pytensor.graph.op import Op
from pytensor.graph.type import Type
from pytensor.graph.fg import FunctionGraph
......
......@@ -1151,53 +1151,6 @@ def clone_get_equiv(
return memo
def clone_replace(
output: Collection[Variable],
replace: Optional[
Union[Iterable[Tuple[Variable, Variable]], Dict[Variable, Variable]]
] = None,
**rebuild_kwds,
) -> List[Variable]:
"""Clone a graph and replace subgraphs within it.
It returns a copy of the initial subgraph with the corresponding
substitutions.
Parameters
----------
output
PyTensor expression that represents the computational graph.
replace
Dictionary describing which subgraphs should be replaced by what.
rebuild_kwds
Keywords to `rebuild_collect_shared`.
"""
from pytensor.compile.function.pfunc import rebuild_collect_shared
items: Union[List[Tuple[Variable, Variable]], Tuple[Tuple[Variable, Variable], ...]]
if isinstance(replace, dict):
items = list(replace.items())
elif isinstance(replace, (list, tuple)):
items = replace
elif replace is None:
items = []
else:
raise ValueError(
"replace is neither a dictionary, list, "
f"tuple or None ! The value provided is {replace},"
f"of type {type(replace)}"
)
tmp_replace = [(x, x.type()) for x, y in items]
new_replace = [(x, y) for ((_, x), (_, y)) in zip(tmp_replace, items)]
_, _outs, _ = rebuild_collect_shared(output, [], tmp_replace, [], **rebuild_kwds)
# TODO Explain why we call it twice ?!
_, outs, _ = rebuild_collect_shared(_outs, [], new_replace, [], **rebuild_kwds)
return cast(List[Variable], outs)
def general_toposort(
outputs: Iterable[T],
deps: Callable[[T], Union[OrderedSet, List[T]]],
......
from typing import (
Collection,
Dict,
Iterable,
List,
Optional,
Sequence,
Tuple,
Union,
cast,
)
from pytensor.graph.basic import Constant, Variable
def clone_replace(
output: Collection[Variable],
replace: Optional[
Union[Iterable[Tuple[Variable, Variable]], Dict[Variable, Variable]]
] = None,
**rebuild_kwds,
) -> List[Variable]:
"""Clone a graph and replace subgraphs within it.
It returns a copy of the initial subgraph with the corresponding
substitutions.
Parameters
----------
output
PyTensor expression that represents the computational graph.
replace
Dictionary describing which subgraphs should be replaced by what.
rebuild_kwds
Keywords to `rebuild_collect_shared`.
"""
from pytensor.compile.function.pfunc import rebuild_collect_shared
items: Union[List[Tuple[Variable, Variable]], Tuple[Tuple[Variable, Variable], ...]]
if isinstance(replace, dict):
items = list(replace.items())
elif isinstance(replace, (list, tuple)):
items = replace
elif replace is None:
items = []
else:
raise ValueError(
"replace is neither a dictionary, list, "
f"tuple or None ! The value provided is {replace},"
f"of type {type(replace)}"
)
tmp_replace = [(x, x.type()) for x, y in items]
new_replace = [(x, y) for ((_, x), (_, y)) in zip(tmp_replace, items)]
_, _outs, _ = rebuild_collect_shared(output, [], tmp_replace, [], **rebuild_kwds)
# TODO Explain why we call it twice ?!
_, outs, _ = rebuild_collect_shared(_outs, [], new_replace, [], **rebuild_kwds)
return cast(List[Variable], outs)
......@@ -20,8 +20,9 @@ import pytensor.tensor as at
from pytensor import as_symbolic
from pytensor.compile import optdb
from pytensor.configdefaults import config
from pytensor.graph.basic import Apply, Variable, clone_replace, is_in_ancestors
from pytensor.graph.basic import Apply, Variable, is_in_ancestors
from pytensor.graph.op import _NoPythonOp
from pytensor.graph.replace import clone_replace
from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
from pytensor.graph.type import HasDataType, HasShape
from pytensor.tensor.shape import Reshape, Shape, SpecifyShape, Unbroadcast
......
......@@ -6,8 +6,9 @@ import pytensor.tensor as at
from pytensor.compile.function.pfunc import construct_pfunc_ins_and_outs
from pytensor.compile.sharedvalue import SharedVariable, collect_new_shareds
from pytensor.configdefaults import config
from pytensor.graph.basic import Constant, Variable, clone_replace, graph_inputs
from pytensor.graph.basic import Constant, Variable, graph_inputs
from pytensor.graph.op import get_test_value
from pytensor.graph.replace import clone_replace
from pytensor.graph.utils import MissingInputError, TestValueError
from pytensor.scan.op import Scan, ScanInfo
from pytensor.scan.utils import expand_empty, safe_new, until
......
......@@ -65,13 +65,13 @@ from pytensor.gradient import DisconnectedType, NullType, Rop, grad, grad_undefi
from pytensor.graph.basic import (
Apply,
Variable,
clone_replace,
equal_computations,
graph_inputs,
io_connection_pattern,
)
from pytensor.graph.features import NoOutputFromInplace
from pytensor.graph.op import HasInnerGraph, Op
from pytensor.graph.replace import clone_replace
from pytensor.graph.utils import InconsistencyError, MissingInputError
from pytensor.link.c.basic import CLinker
from pytensor.link.c.exceptions import MissingGXX
......
......@@ -18,7 +18,6 @@ from pytensor.graph.basic import (
Apply,
Constant,
Variable,
clone_replace,
equal_computations,
graph_inputs,
io_toposort,
......@@ -28,6 +27,7 @@ from pytensor.graph.destroyhandler import DestroyHandler
from pytensor.graph.features import ReplaceValidate
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import compute_test_value
from pytensor.graph.replace import clone_replace
from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
from pytensor.graph.rewriting.db import EquilibriumDB, SequenceDB
from pytensor.graph.type import HasShape
......
......@@ -14,14 +14,9 @@ from pytensor import scalar as aes
from pytensor import tensor as at
from pytensor.compile.profiling import ProfileStats
from pytensor.configdefaults import config
from pytensor.graph.basic import (
Constant,
Variable,
clone_replace,
equal_computations,
graph_inputs,
)
from pytensor.graph.basic import Constant, Variable, equal_computations, graph_inputs
from pytensor.graph.op import get_test_value
from pytensor.graph.replace import clone_replace
from pytensor.graph.type import HasDataType
from pytensor.graph.utils import TestValueError
from pytensor.tensor.basic import AllocEmpty, cast
......
......@@ -4,7 +4,7 @@ from itertools import count
import numpy as np
import pytest
from pytensor import config, function, shared
from pytensor import shared
from pytensor import tensor as at
from pytensor.graph.basic import (
Apply,
......@@ -15,7 +15,6 @@ from pytensor.graph.basic import (
as_string,
clone,
clone_get_equiv,
clone_replace,
equal_computations,
general_toposort,
get_var_by_name,
......@@ -30,18 +29,9 @@ from pytensor.graph.basic import (
from pytensor.graph.op import Op
from pytensor.graph.type import Type
from pytensor.tensor.math import max_and_argmax
from pytensor.tensor.type import (
TensorType,
dvector,
fvector,
iscalars,
matrix,
scalars,
vector,
)
from pytensor.tensor.type import TensorType, iscalars, matrix, scalars, vector
from pytensor.tensor.type_other import NoneConst
from pytensor.tensor.var import TensorVariable
from tests import unittest_tools as utt
from tests.graph.utils import MyInnerGraphOp
......@@ -557,131 +547,6 @@ def test_get_var_by_name():
assert res == exp_res
class TestCloneReplace:
def test_cloning_no_replace_strict_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
x = vector("x")
y = vector("y")
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(f1, replace=None, rebuild_strict=True, copy_inputs_over=True)
f2_inp = graph_inputs([f2])
assert z in f2_inp
assert x in f2_inp
assert y in f2_inp
def test_cloning_no_replace_strict_not_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
x = vector("x")
y = vector("y")
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(
f1, replace=None, rebuild_strict=True, copy_inputs_over=False
)
f2_inp = graph_inputs([f2])
assert z not in f2_inp
assert x not in f2_inp
assert y not in f2_inp
def test_cloning_replace_strict_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
x = vector("x")
y = vector("y")
y2 = vector("y2")
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(
f1, replace={y: y2}, rebuild_strict=True, copy_inputs_over=True
)
f2_inp = graph_inputs([f2])
assert z in f2_inp
assert x in f2_inp
assert y2 in f2_inp
def test_cloning_replace_not_strict_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
x = vector("x")
y = fvector("y")
y2 = dvector("y2")
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(
f1, replace={y: y2}, rebuild_strict=False, copy_inputs_over=True
)
f2_inp = graph_inputs([f2])
assert z in f2_inp
assert x in f2_inp
assert y2 in f2_inp
def test_cloning_replace_strict_not_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
x = vector("x")
y = vector("y")
y2 = vector("y2")
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(
f1, replace=[(y, y2)], rebuild_strict=True, copy_inputs_over=False
)
f2_inp = graph_inputs([f2])
assert z not in f2_inp
assert x not in f2_inp
assert y2 not in f2_inp
def test_cloning_replace_not_strict_not_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
x = vector("x")
y = fvector("y")
y2 = dvector("y2")
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(
f1, replace=[(y, y2)], rebuild_strict=False, copy_inputs_over=False
)
f2_inp = graph_inputs([f2])
assert z not in f2_inp
assert x not in f2_inp
assert y2 not in f2_inp
def test_clone(self):
def test(x, y, mention_y):
if mention_y:
d = 0.1 + 0 * y
else:
d = 0.1
out = clone_replace(y, replace={x: x + d})
return function([], out)()
x = shared(np.asarray(0.0, dtype=config.floatX))
utt.assert_allclose(
test(x, at.sum((x + 1) ** 2), mention_y=False), 1.21000003815
)
utt.assert_allclose(
test(x, at.sum((x + 1) ** 2), mention_y=True), 1.21000003815
)
def test_clone_new_inputs():
"""Make sure that `Apply.clone_with_new_inputs` properly handles `Type` changes."""
......
import numpy as np
import pytest
import pytensor.tensor as pt
from pytensor import config, function, shared
from pytensor.graph.basic import graph_inputs
from pytensor.graph.replace import clone_replace
from pytensor.tensor import dvector, fvector, vector
from tests import unittest_tools as utt
from tests.graph.utils import MyOp, MyVariable
class TestCloneReplace:
def test_cloning_no_replace_strict_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
x = vector("x")
y = vector("y")
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(f1, replace=None, rebuild_strict=True, copy_inputs_over=True)
f2_inp = graph_inputs([f2])
assert z in f2_inp
assert x in f2_inp
assert y in f2_inp
def test_cloning_no_replace_strict_not_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
x = vector("x")
y = vector("y")
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(
f1, replace=None, rebuild_strict=True, copy_inputs_over=False
)
f2_inp = graph_inputs([f2])
assert z not in f2_inp
assert x not in f2_inp
assert y not in f2_inp
def test_cloning_replace_strict_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
x = vector("x")
y = vector("y")
y2 = vector("y2")
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(
f1, replace={y: y2}, rebuild_strict=True, copy_inputs_over=True
)
f2_inp = graph_inputs([f2])
assert z in f2_inp
assert x in f2_inp
assert y2 in f2_inp
def test_cloning_replace_not_strict_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
x = vector("x")
y = fvector("y")
y2 = dvector("y2")
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(
f1, replace={y: y2}, rebuild_strict=False, copy_inputs_over=True
)
f2_inp = graph_inputs([f2])
assert z in f2_inp
assert x in f2_inp
assert y2 in f2_inp
def test_cloning_replace_strict_not_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
x = vector("x")
y = vector("y")
y2 = vector("y2")
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(
f1, replace=[(y, y2)], rebuild_strict=True, copy_inputs_over=False
)
f2_inp = graph_inputs([f2])
assert z not in f2_inp
assert x not in f2_inp
assert y2 not in f2_inp
def test_cloning_replace_not_strict_not_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
x = vector("x")
y = fvector("y")
y2 = dvector("y2")
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(
f1, replace=[(y, y2)], rebuild_strict=False, copy_inputs_over=False
)
f2_inp = graph_inputs([f2])
assert z not in f2_inp
assert x not in f2_inp
assert y2 not in f2_inp
def test_clone(self):
def test(x, y, mention_y):
if mention_y:
d = 0.1 + 0 * y
else:
d = 0.1
out = clone_replace(y, replace={x: x + d})
return function([], out)()
x = shared(np.asarray(0.0, dtype=config.floatX))
utt.assert_allclose(
test(x, pt.sum((x + 1) ** 2), mention_y=False), 1.21000003815
)
utt.assert_allclose(
test(x, pt.sum((x + 1) ** 2), mention_y=True), 1.21000003815
)
import numpy as np
from pytensor.graph.basic import (
Apply,
Constant,
NominalVariable,
Variable,
clone_replace,
)
from pytensor.graph.basic import Apply, Constant, NominalVariable, Variable
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import HasInnerGraph, Op
from pytensor.graph.replace import clone_replace
from pytensor.graph.type import Type
......
......@@ -9,8 +9,9 @@ from pytensor.compile.io import In
from pytensor.compile.mode import get_default_mode
from pytensor.configdefaults import config
from pytensor.gradient import grad, jacobian
from pytensor.graph.basic import clone_replace, equal_computations
from pytensor.graph.basic import equal_computations
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.replace import clone_replace
from pytensor.scan.op import Scan
from pytensor.scan.rewriting import ScanInplaceOptimizer, ScanMerge
from pytensor.scan.utils import until
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论