提交 862aec54 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Thomas Wiecki

Rename theano.scan.utils.clone to theano.graph.basic.clone_replace

上级 396ecc10
...@@ -58,6 +58,6 @@ There are also some top-level imports that you might find more convenient: ...@@ -58,6 +58,6 @@ There are also some top-level imports that you might find more convenient:
Works like :func:`tensor.dot` for both sparse and dense matrix products Works like :func:`tensor.dot` for both sparse and dense matrix products
.. autofunction:: theano.clone .. autofunction:: theano.clone_replace
.. autofunction:: theano.sparse_grad .. autofunction:: theano.sparse_grad
...@@ -52,7 +52,7 @@ def test_rop_lop(): ...@@ -52,7 +52,7 @@ def test_rop_lop():
raised = False raised = False
try: try:
theano.gradient.Rop(theano.clone(y, replace={mx: break_op(mx)}), mx, mv) theano.gradient.Rop(theano.clone_replace(y, replace={mx: break_op(mx)}), mx, mv)
except ValueError: except ValueError:
raised = True raised = True
if not raised: if not raised:
......
...@@ -37,6 +37,7 @@ from theano.gradient import ( ...@@ -37,6 +37,7 @@ from theano.gradient import (
hessian, hessian,
jacobian, jacobian,
) )
from theano.graph.basic import clone_replace
from theano.misc.safe_asarray import _asarray from theano.misc.safe_asarray import _asarray
from theano.scan.basic import scan from theano.scan.basic import scan
from theano.scan.op import Scan from theano.scan.op import Scan
...@@ -1065,7 +1066,7 @@ class TestScan: ...@@ -1065,7 +1066,7 @@ class TestScan:
x0[0] = vx0 x0[0] = vx0
x0 = tt.constant(x0) x0 = tt.constant(x0)
to_replace = outputs[0].owner.inputs[0].owner.inputs[1] to_replace = outputs[0].owner.inputs[0].owner.inputs[1]
outputs = theano.clone(outputs, replace=[(to_replace, x0)]) outputs = clone_replace(outputs, replace=[(to_replace, x0)])
mode = get_mode(None).including("inplace") mode = get_mode(None).including("inplace")
f9 = theano.function([], outputs, updates=updates, mode=mode) f9 = theano.function([], outputs, updates=updates, mode=mode)
scan_node = [x for x in f9.maker.fgraph.toposort() if isinstance(x.op, Scan)] scan_node = [x for x in f9.maker.fgraph.toposort() if isinstance(x.op, Scan)]
...@@ -1981,7 +1982,7 @@ class TestScan: ...@@ -1981,7 +1982,7 @@ class TestScan:
z = theano.shared(0.25) z = theano.shared(0.25)
f1 = z * (x + y) ** 2 + 5 f1 = z * (x + y) ** 2 + 5
f2 = theano.clone(f1, replace=None, strict=True, share_inputs=True) f2 = clone_replace(f1, replace=None, strict=True, share_inputs=True)
f2_inp = theano.graph.basic.graph_inputs([f2]) f2_inp = theano.graph.basic.graph_inputs([f2])
assert z in f2_inp assert z in f2_inp
...@@ -1997,7 +1998,7 @@ class TestScan: ...@@ -1997,7 +1998,7 @@ class TestScan:
z = theano.shared(0.25) z = theano.shared(0.25)
f1 = z * (x + y) ** 2 + 5 f1 = z * (x + y) ** 2 + 5
f2 = theano.clone(f1, replace=None, strict=True, share_inputs=False) f2 = clone_replace(f1, replace=None, strict=True, share_inputs=False)
f2_inp = theano.graph.basic.graph_inputs([f2]) f2_inp = theano.graph.basic.graph_inputs([f2])
assert z not in f2_inp assert z not in f2_inp
...@@ -2014,7 +2015,7 @@ class TestScan: ...@@ -2014,7 +2015,7 @@ class TestScan:
z = theano.shared(0.25) z = theano.shared(0.25)
f1 = z * (x + y) ** 2 + 5 f1 = z * (x + y) ** 2 + 5
f2 = theano.clone( f2 = clone_replace(
f1, replace=OrderedDict([(y, y2)]), strict=True, share_inputs=True f1, replace=OrderedDict([(y, y2)]), strict=True, share_inputs=True
) )
f2_inp = theano.graph.basic.graph_inputs([f2]) f2_inp = theano.graph.basic.graph_inputs([f2])
...@@ -2032,7 +2033,7 @@ class TestScan: ...@@ -2032,7 +2033,7 @@ class TestScan:
z = theano.shared(0.25) z = theano.shared(0.25)
f1 = z * (x + y) ** 2 + 5 f1 = z * (x + y) ** 2 + 5
f2 = theano.clone( f2 = clone_replace(
f1, replace=OrderedDict([(y, y2)]), strict=False, share_inputs=True f1, replace=OrderedDict([(y, y2)]), strict=False, share_inputs=True
) )
f2_inp = theano.graph.basic.graph_inputs([f2]) f2_inp = theano.graph.basic.graph_inputs([f2])
...@@ -2050,7 +2051,7 @@ class TestScan: ...@@ -2050,7 +2051,7 @@ class TestScan:
z = theano.shared(0.25) z = theano.shared(0.25)
f1 = z * (x + y) ** 2 + 5 f1 = z * (x + y) ** 2 + 5
f2 = theano.clone(f1, replace=[(y, y2)], strict=True, share_inputs=False) f2 = clone_replace(f1, replace=[(y, y2)], strict=True, share_inputs=False)
f2_inp = theano.graph.basic.graph_inputs([f2]) f2_inp = theano.graph.basic.graph_inputs([f2])
assert z not in f2_inp assert z not in f2_inp
assert x not in f2_inp assert x not in f2_inp
...@@ -2066,7 +2067,7 @@ class TestScan: ...@@ -2066,7 +2067,7 @@ class TestScan:
z = theano.shared(0.25) z = theano.shared(0.25)
f1 = z * (x + y) ** 2 + 5 f1 = z * (x + y) ** 2 + 5
f2 = theano.clone(f1, replace=[(y, y2)], strict=False, share_inputs=False) f2 = clone_replace(f1, replace=[(y, y2)], strict=False, share_inputs=False)
f2_inp = theano.graph.basic.graph_inputs([f2]) f2_inp = theano.graph.basic.graph_inputs([f2])
assert z not in f2_inp assert z not in f2_inp
assert x not in f2_inp assert x not in f2_inp
...@@ -4260,7 +4261,7 @@ class TestScan: ...@@ -4260,7 +4261,7 @@ class TestScan:
d = 0.1 + 0 * y d = 0.1 + 0 * y
else: else:
d = 0.1 d = 0.1
out = theano.clone(y, replace={x: x + d}) out = clone_replace(y, replace={x: x + d})
# theano.printing.debugprint(out) # theano.printing.debugprint(out)
return theano.function([], out)() return theano.function([], out)()
......
...@@ -94,7 +94,7 @@ class TestGaussNewton: ...@@ -94,7 +94,7 @@ class TestGaussNewton:
# during certain iterations of CG in the HF algorithm. There, # during certain iterations of CG in the HF algorithm. There,
# it's in fact `pi + current update proposal`. For simplicity, # it's in fact `pi + current update proposal`. For simplicity,
# I just multiply by 2 here. # I just multiply by 2 here.
cost_ = theano.clone(cost, replace={pi: 2 * pi for pi in params}) cost_ = theano.clone_replace(cost, replace={pi: 2 * pi for pi in params})
# Compute Gauss-Newton-Matrix times some vector `v` which is `p` in CG, # Compute Gauss-Newton-Matrix times some vector `v` which is `p` in CG,
# but for simplicity, I just take the parameters vector because it's # but for simplicity, I just take the parameters vector because it's
......
...@@ -1302,7 +1302,7 @@ def test_gt_grad(): ...@@ -1302,7 +1302,7 @@ def test_gt_grad():
W = theano.shared(value=W_values, name="weights") W = theano.shared(value=W_values, name="weights")
correct_score = tt.dot(input_, W) correct_score = tt.dot(input_, W)
wrong_input = vector(dtype=floatX) wrong_input = vector(dtype=floatX)
wrong_score = theano.clone(correct_score, {input_: wrong_input}) wrong_score = theano.clone_replace(correct_score, {input_: wrong_input})
# Hinge loss # Hinge loss
scores = tt.ones_like(correct_score) - correct_score + wrong_score scores = tt.ones_like(correct_score) - correct_score + wrong_score
cost = (scores * (scores > 0)).sum() cost = (scores * (scores > 0)).sum()
......
...@@ -119,7 +119,9 @@ class RopLopChecker: ...@@ -119,7 +119,9 @@ class RopLopChecker:
assert np.allclose(v1, v2), f"ROP mismatch: {v1} {v2}" assert np.allclose(v1, v2), f"ROP mismatch: {v1} {v2}"
self.check_nondiff_rop(theano.clone(y, replace={self.mx: break_op(self.mx)})) self.check_nondiff_rop(
theano.clone_replace(y, replace={self.mx: break_op(self.mx)})
)
vv = np.asarray(self.rng.uniform(size=out_shape), theano.config.floatX) vv = np.asarray(self.rng.uniform(size=out_shape), theano.config.floatX)
yv = Lop(y, self.mx, self.v) yv = Lop(y, self.mx, self.v)
...@@ -157,7 +159,11 @@ class RopLopChecker: ...@@ -157,7 +159,11 @@ class RopLopChecker:
assert np.allclose(v1, v2), f"ROP mismatch: {v1} {v2}" assert np.allclose(v1, v2), f"ROP mismatch: {v1} {v2}"
try: try:
Rop(theano.clone(y, replace={self.x: break_op(self.x)}), self.x, self.v) Rop(
theano.clone_replace(y, replace={self.x: break_op(self.x)}),
self.x,
self.v,
)
except ValueError: except ValueError:
pytest.skip( pytest.skip(
"Rop does not handle non-differentiable inputs " "Rop does not handle non-differentiable inputs "
......
...@@ -147,7 +147,10 @@ def sparse_grad(var): ...@@ -147,7 +147,10 @@ def sparse_grad(var):
import theano.tensor.random.var import theano.tensor.random.var
from theano.scan import checkpoints, clone, foldl, foldr, map, reduce, scan from theano.graph.basic import clone_replace
from theano.scan import checkpoints
from theano.scan.basic import scan
from theano.scan.views import foldl, foldr, map, reduce
# Some config variables are registered by submodules. Only after all those imports # Some config variables are registered by submodules. Only after all those imports
......
...@@ -781,12 +781,12 @@ class OpFromGraph(Op): ...@@ -781,12 +781,12 @@ class OpFromGraph(Op):
# Clone the output shape so that shape are computed from outer inputs. # Clone the output shape so that shape are computed from outer inputs.
# Note: # Note:
# Here we can do it more simply like: # Here we can do it more simply like:
# ret = [theano.clone(shp, replace=repl) for shp in out_shp] # ret = [theano.clone_replace(shp, replace=repl) for shp in out_shp]
# But doing it multiple time could duplicate common subgraph between # But doing it multiple time could duplicate common subgraph between
# each shape call. Theano optimizer will clean this up later, but this # each shape call. Theano optimizer will clean this up later, but this
# will ask extra work to the optimizer. # will ask extra work to the optimizer.
repl = dict(zip(self.local_inputs, node.inputs)) repl = dict(zip(self.local_inputs, node.inputs))
cloned = theano.clone(reduce(tuple.__add__, out_shp), replace=repl) cloned = theano.clone_replace(reduce(tuple.__add__, out_shp), replace=repl)
ret = [] ret = []
used = 0 used = 0
for i in range(len(out_shp)): for i in range(len(out_shp)):
...@@ -824,7 +824,7 @@ def inline_ofg_expansion(fgraph, node): ...@@ -824,7 +824,7 @@ def inline_ofg_expansion(fgraph, node):
return False return False
if not op.is_inline: if not op.is_inline:
return False return False
return theano.clone( return theano.clone_replace(
op.local_outputs, {u: v for u, v in zip(node.op.local_inputs, node.inputs)} op.local_outputs, {u: v for u, v in zip(node.op.local_inputs, node.inputs)}
) )
......
...@@ -454,9 +454,9 @@ def pfunc( ...@@ -454,9 +454,9 @@ def pfunc(
"has no effect. One way to modify an input `x` to a function " "has no effect. One way to modify an input `x` to a function "
"evaluating f(x) is to define a new input `y` and use " "evaluating f(x) is to define a new input `y` and use "
"`theano.function([y], f(x), givens={x: g(y)})`. Another " "`theano.function([y], f(x), givens={x: g(y)})`. Another "
"solution consists in using `theano.clone`, e.g. like this: " "solution consists in using `theano.clone_replace`, e.g. like this: "
"`theano.function([x], " "`theano.function([x], "
"theano.clone(f(x), replace={x: g(x)}))`." "theano.clone_replace(f(x), replace={x: g(x)}))`."
) )
# Extend the outputs with the updates on input variables so they are also # Extend the outputs with the updates on input variables so they are also
......
...@@ -151,7 +151,7 @@ from theano.gpuarray.type import ( ...@@ -151,7 +151,7 @@ from theano.gpuarray.type import (
move_to_gpu, move_to_gpu,
) )
from theano.graph import toolbox from theano.graph import toolbox
from theano.graph.basic import Constant, Variable, applys_between from theano.graph.basic import Constant, Variable, applys_between, clone_replace
from theano.graph.fg import FunctionGraph from theano.graph.fg import FunctionGraph
from theano.graph.opt import ( from theano.graph.opt import (
GlobalOptimizer, GlobalOptimizer,
...@@ -165,7 +165,6 @@ from theano.link.c.basic import CLinker ...@@ -165,7 +165,6 @@ from theano.link.c.basic import CLinker
from theano.misc.ordered_set import OrderedSet from theano.misc.ordered_set import OrderedSet
from theano.scalar.basic import Cast, Pow, Scalar, log, neg, true_div from theano.scalar.basic import Cast, Pow, Scalar, log, neg, true_div
from theano.scalar.basic_scipy import Erfcinv, Erfinv from theano.scalar.basic_scipy import Erfcinv, Erfinv
from theano.scan import utils
from theano.scan.op import Scan from theano.scan.op import Scan
from theano.scan.opt import ScanInplaceOptimizer from theano.scan.opt import ScanInplaceOptimizer
from theano.tensor.basic import ( from theano.tensor.basic import (
...@@ -2661,7 +2660,7 @@ def gpu_reconstruct_graph(inputs, outputs, tag=None): ...@@ -2661,7 +2660,7 @@ def gpu_reconstruct_graph(inputs, outputs, tag=None):
givens = {} givens = {}
for nw_x, x in zip(nw_inputs, inputs): for nw_x, x in zip(nw_inputs, inputs):
givens[x] = nw_x givens[x] = nw_x
nw_outputs = utils.clone(outputs, replace=givens) nw_outputs = clone_replace(outputs, replace=givens)
return (nw_inputs, nw_outputs) return (nw_inputs, nw_outputs)
...@@ -2689,7 +2688,7 @@ def local_gpua_scan_to_gpua(fgraph, op, context_name, inputs, outputs): ...@@ -2689,7 +2688,7 @@ def local_gpua_scan_to_gpua(fgraph, op, context_name, inputs, outputs):
scan_outs += [op.outputs[-1]] scan_outs += [op.outputs[-1]]
else: else:
scan_outs = [safe_to_gpu(x, context_name) for x in op.outputs] scan_outs = [safe_to_gpu(x, context_name) for x in op.outputs]
scan_outs = utils.clone( scan_outs = clone_replace(
scan_outs, replace=list(zip(op.inputs, (safe_to_cpu(x) for x in scan_ins))) scan_outs, replace=list(zip(op.inputs, (safe_to_cpu(x) for x in scan_ins)))
) )
......
...@@ -16,6 +16,7 @@ from typing import ( ...@@ -16,6 +16,7 @@ from typing import (
Optional, Optional,
Sequence, Sequence,
Set, Set,
Tuple,
TypeVar, TypeVar,
Union, Union,
) )
...@@ -863,7 +864,12 @@ def applys_between( ...@@ -863,7 +864,12 @@ def applys_between(
) )
def clone(inputs, outputs, copy_inputs=True, copy_orphans=None): def clone(
inputs: Collection[Variable],
outputs: Collection[Variable],
copy_inputs: bool = True,
copy_orphans: Optional[bool] = None,
) -> Tuple[Collection[Variable], Collection[Variable]]:
"""Copies the sub-graph contained between inputs and outputs. """Copies the sub-graph contained between inputs and outputs.
Parameters Parameters
...@@ -898,7 +904,13 @@ def clone(inputs, outputs, copy_inputs=True, copy_orphans=None): ...@@ -898,7 +904,13 @@ def clone(inputs, outputs, copy_inputs=True, copy_orphans=None):
return [equiv[input] for input in inputs], [equiv[output] for output in outputs] return [equiv[input] for input in inputs], [equiv[output] for output in outputs]
def clone_get_equiv(inputs, outputs, copy_inputs=True, copy_orphans=True, memo=None): def clone_get_equiv(
inputs: Collection[Variable],
outputs: Collection[Variable],
copy_inputs: bool = True,
copy_orphans: bool = True,
memo: Optional[Dict[Variable, Variable]] = None,
):
""" """
Return a dictionary that maps from Variable and Apply nodes in the Return a dictionary that maps from Variable and Apply nodes in the
original graph to a new node (a clone) in a new graph. original graph to a new node (a clone) in a new graph.
...@@ -960,6 +972,60 @@ def clone_get_equiv(inputs, outputs, copy_inputs=True, copy_orphans=True, memo=N ...@@ -960,6 +972,60 @@ def clone_get_equiv(inputs, outputs, copy_inputs=True, copy_orphans=True, memo=N
return memo return memo
def clone_replace(
output: Collection[Variable],
replace: Optional[Dict[Variable, Variable]] = None,
strict: bool = True,
share_inputs: bool = True,
) -> Collection[Variable]:
"""Clone a graph and replace subgraphs within it.
It returns a copy of the initial subgraph with the corresponding
substitutions.
Parameters
----------
output : Theano Variables (or Theano expressions)
Theano expression that represents the computational graph.
replace : dict
Dictionary describing which subgraphs should be replaced by what.
share_inputs : bool
If True, use the same inputs (and shared variables) as the original
graph. If False, clone them. Note that cloned shared variables still
use the same underlying storage, so they will always have the same
value.
"""
from theano.compile.function.pfunc import rebuild_collect_shared
if isinstance(replace, dict):
items = list(replace.items())
elif isinstance(replace, (list, tuple)):
items = replace
elif replace is None:
items = []
else:
raise ValueError(
(
"replace is neither a dictionary, list, "
f"tuple or None ! The value provided is {replace},"
f"of type {type(replace)}"
)
)
tmp_replace = [(x, x.type()) for x, y in items]
new_replace = [(x, y) for ((_, x), (_, y)) in zip(tmp_replace, items)]
_, _outs, _ = rebuild_collect_shared(
output, [], tmp_replace, [], strict, share_inputs
)
# TODO Explain why we call it twice ?!
_, outs, _ = rebuild_collect_shared(
_outs, [], new_replace, [], strict, share_inputs
)
return outs
def general_toposort( def general_toposort(
outputs: Iterable[T], outputs: Iterable[T],
deps: Callable[[T], Union[OrderedSet, List[T]]], deps: Callable[[T], Union[OrderedSet, List[T]]],
......
...@@ -19,10 +19,9 @@ import numpy as np ...@@ -19,10 +19,9 @@ import numpy as np
import theano.tensor import theano.tensor
from theano.compile import optdb from theano.compile import optdb
from theano.configdefaults import config from theano.configdefaults import config
from theano.graph.basic import Apply, Variable, is_in_ancestors from theano.graph.basic import Apply, Variable, clone_replace, is_in_ancestors
from theano.graph.op import _NoPythonOp from theano.graph.op import _NoPythonOp
from theano.graph.opt import GlobalOptimizer, local_optimizer from theano.graph.opt import GlobalOptimizer, local_optimizer
from theano.scan.utils import clone
from theano.tensor import opt from theano.tensor import opt
from theano.tensor.type import TensorType from theano.tensor.type import TensorType
...@@ -540,8 +539,8 @@ def ifelse_lift_single_if_through_acceptable_ops(fgraph, main_node): ...@@ -540,8 +539,8 @@ def ifelse_lift_single_if_through_acceptable_ops(fgraph, main_node):
false_ins.append(x) false_ins.append(x)
true_eval = mop(*true_ins, **dict(return_list=True)) true_eval = mop(*true_ins, **dict(return_list=True))
false_eval = mop(*false_ins, **dict(return_list=True)) false_eval = mop(*false_ins, **dict(return_list=True))
# true_eval = clone(outs, replace = dict(zip(node.outputs, ts))) # true_eval = clone_replace(outs, replace = dict(zip(node.outputs, ts)))
# false_eval = clone(outs, replace = dict(zip(node.outputs, fs))) # false_eval = clone_replace(outs, replace = dict(zip(node.outputs, fs)))
nw_outs = ifelse(node.inputs[0], true_eval, false_eval, return_list=True) nw_outs = ifelse(node.inputs[0], true_eval, false_eval, return_list=True)
return nw_outs return nw_outs
...@@ -641,7 +640,7 @@ class CondMerge(GlobalOptimizer): ...@@ -641,7 +640,7 @@ class CondMerge(GlobalOptimizer):
) )
print("here") print("here")
new_outs = new_ifelse(*new_ins, **dict(return_list=True)) new_outs = new_ifelse(*new_ins, **dict(return_list=True))
new_outs = [clone(x) for x in new_outs] new_outs = [clone_replace(x) for x in new_outs]
old_outs = [] old_outs = []
if type(merging_node.outputs) not in (list, tuple): if type(merging_node.outputs) not in (list, tuple):
old_outs += [merging_node.outputs] old_outs += [merging_node.outputs]
...@@ -752,7 +751,7 @@ def cond_merge_random_op(fgraph, main_node): ...@@ -752,7 +751,7 @@ def cond_merge_random_op(fgraph, main_node):
else: else:
old_outs += proposal.outputs old_outs += proposal.outputs
pairs = list(zip(old_outs, new_outs)) pairs = list(zip(old_outs, new_outs))
main_outs = clone(main_node.outputs, replace=pairs) main_outs = clone_replace(main_node.outputs, replace=pairs)
return main_outs return main_outs
......
...@@ -43,9 +43,9 @@ __authors__ = ( ...@@ -43,9 +43,9 @@ __authors__ = (
"James Bergstra " "James Bergstra "
"Pascal Lamblin " "Pascal Lamblin "
"Arnaud Bergeron " "Arnaud Bergeron "
"PyMC Developers "
) )
__copyright__ = "(c) 2010, Universite de Montreal" __copyright__ = "(c) 2010, Universite de Montreal"
__contact__ = "Razvan Pascanu <r.pascanu@gmail>"
from theano import configdefaults from theano import configdefaults
...@@ -55,5 +55,5 @@ configdefaults.add_scan_configvars() ...@@ -55,5 +55,5 @@ configdefaults.add_scan_configvars()
from theano.scan import opt from theano.scan import opt
from theano.scan.basic import scan from theano.scan.basic import scan
from theano.scan.checkpoints import scan_checkpoints from theano.scan.checkpoints import scan_checkpoints
from theano.scan.utils import clone, until from theano.scan.utils import until
from theano.scan.views import foldl, foldr, map, reduce from theano.scan.views import foldl, foldr, map, reduce
...@@ -19,7 +19,7 @@ from theano.compile import SharedVariable, ops ...@@ -19,7 +19,7 @@ from theano.compile import SharedVariable, ops
from theano.compile.function import function from theano.compile.function import function
from theano.compile.mode import Mode from theano.compile.mode import Mode
from theano.configdefaults import config from theano.configdefaults import config
from theano.graph.basic import Constant, Variable, graph_inputs from theano.graph.basic import Constant, Variable, clone_replace, graph_inputs
from theano.graph.fg import MissingInputError from theano.graph.fg import MissingInputError
from theano.graph.op import get_test_value from theano.graph.op import get_test_value
from theano.graph.utils import TestValueError from theano.graph.utils import TestValueError
...@@ -798,7 +798,7 @@ def scan( ...@@ -798,7 +798,7 @@ def scan(
if condition is not None: if condition is not None:
outputs.append(condition) outputs.append(condition)
fake_nonseqs = [x.type() for x in non_seqs] fake_nonseqs = [x.type() for x in non_seqs]
fake_outputs = utils.clone( fake_outputs = clone_replace(
outputs, replace=OrderedDict(zip(non_seqs, fake_nonseqs)) outputs, replace=OrderedDict(zip(non_seqs, fake_nonseqs))
) )
all_inputs = filter( all_inputs = filter(
...@@ -1025,7 +1025,7 @@ def scan( ...@@ -1025,7 +1025,7 @@ def scan(
else: else:
new_givens = givens new_givens = givens
new_outs = utils.clone(inner_outs, replace=new_givens) new_outs = clone_replace(inner_outs, replace=new_givens)
## ##
# Step 7. Create the Scan Op # Step 7. Create the Scan Op
......
...@@ -65,6 +65,7 @@ from theano.graph.basic import ( ...@@ -65,6 +65,7 @@ from theano.graph.basic import (
Apply, Apply,
Constant, Constant,
Variable, Variable,
clone_replace,
equal_computations, equal_computations,
graph_inputs, graph_inputs,
io_connection_pattern, io_connection_pattern,
...@@ -75,13 +76,7 @@ from theano.graph.toolbox import NoOutputFromInplace ...@@ -75,13 +76,7 @@ from theano.graph.toolbox import NoOutputFromInplace
from theano.link.c.basic import CLinker from theano.link.c.basic import CLinker
from theano.link.c.exceptions import MissingGXX from theano.link.c.exceptions import MissingGXX
from theano.link.utils import raise_with_op from theano.link.utils import raise_with_op
from theano.scan.utils import ( from theano.scan.utils import Validator, forced_replace, hash_listsDictsTuples, safe_new
Validator,
clone,
forced_replace,
hash_listsDictsTuples,
safe_new,
)
from theano.tensor.basic import as_tensor_variable from theano.tensor.basic import as_tensor_variable
from theano.tensor.opt import Shape_i from theano.tensor.opt import Shape_i
from theano.tensor.type import TensorType, integer_dtypes from theano.tensor.type import TensorType, integer_dtypes
...@@ -2485,7 +2480,7 @@ class Scan(Op): ...@@ -2485,7 +2480,7 @@ class Scan(Op):
replacement = inner_inp_mitmot[-replacement_idx] replacement = inner_inp_mitmot[-replacement_idx]
self.tap_array[idx] self.tap_array[idx]
new_inner_out_mitmot = clone( new_inner_out_mitmot = clone_replace(
new_inner_out_mitmot, replace=[(to_replace, replacement)] new_inner_out_mitmot, replace=[(to_replace, replacement)]
) )
......
...@@ -66,6 +66,7 @@ from theano.configdefaults import config ...@@ -66,6 +66,7 @@ from theano.configdefaults import config
from theano.graph.basic import ( from theano.graph.basic import (
Constant, Constant,
Variable, Variable,
clone_replace,
equal_computations, equal_computations,
graph_inputs, graph_inputs,
io_toposort, io_toposort,
...@@ -78,7 +79,6 @@ from theano.graph.optdb import EquilibriumDB, SequenceDB ...@@ -78,7 +79,6 @@ from theano.graph.optdb import EquilibriumDB, SequenceDB
from theano.graph.toolbox import ReplaceValidate from theano.graph.toolbox import ReplaceValidate
from theano.scan.op import Scan from theano.scan.op import Scan
from theano.scan.utils import ( from theano.scan.utils import (
clone,
compress_outs, compress_outs,
expand_empty, expand_empty,
reconstruct_graph, reconstruct_graph,
...@@ -231,7 +231,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node): ...@@ -231,7 +231,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
nw_outer.extend(nw_outer_nonseq) nw_outer.extend(nw_outer_nonseq)
if len(nw_inner) != len(op_ins): if len(nw_inner) != len(op_ins):
op_outs = clone(op_outs, replace=givens) op_outs = clone_replace(op_outs, replace=givens)
nw_info = copy.deepcopy(op.info) nw_info = copy.deepcopy(op.info)
nw_info["n_seqs"] = nw_n_seqs nw_info["n_seqs"] = nw_n_seqs
# DEBUG CHECK # DEBUG CHECK
...@@ -403,7 +403,7 @@ class PushOutNonSeqScan(GlobalOptimizer): ...@@ -403,7 +403,7 @@ class PushOutNonSeqScan(GlobalOptimizer):
nw_outer.append(repl_out) nw_outer.append(repl_out)
givens[to_repl] = repl_in givens[to_repl] = repl_in
op_outs = clone(clean_outputs, replace=givens) op_outs = clone_replace(clean_outputs, replace=givens)
op_ins = clean_inputs + nw_inner op_ins = clean_inputs + nw_inner
# Reconstruct node # Reconstruct node
...@@ -662,7 +662,7 @@ class PushOutSeqScan(GlobalOptimizer): ...@@ -662,7 +662,7 @@ class PushOutSeqScan(GlobalOptimizer):
givens[to_repl] = repl_in givens[to_repl] = repl_in
op_outs = clone(clean_outputs, replace=givens) op_outs = clone_replace(clean_outputs, replace=givens)
op_ins = nw_inner + clean_inputs op_ins = nw_inner + clean_inputs
# Reconstruct node # Reconstruct node
...@@ -2005,7 +2005,7 @@ def scan_merge_inouts(fgraph, node): ...@@ -2005,7 +2005,7 @@ def scan_merge_inouts(fgraph, node):
outer_inputs = a.outer_inputs outer_inputs = a.outer_inputs
info = a.info info = a.info
a_inner_outs = a.inner_outputs a_inner_outs = a.inner_outputs
inner_outputs = clone(a_inner_outs, replace=inp_equiv) inner_outputs = clone_replace(a_inner_outs, replace=inp_equiv)
op = Scan(inner_inputs, inner_outputs, info) op = Scan(inner_inputs, inner_outputs, info)
outputs = op(*outer_inputs) outputs = op(*outer_inputs)
......
...@@ -7,9 +7,9 @@ __authors__ = ( ...@@ -7,9 +7,9 @@ __authors__ = (
"James Bergstra " "James Bergstra "
"Pascal Lamblin " "Pascal Lamblin "
"Arnaud Bergeron" "Arnaud Bergeron"
"PyMC Developers"
) )
__copyright__ = "(c) 2010, Universite de Montreal" __copyright__ = "(c) 2010, Universite de Montreal"
__contact__ = "Razvan Pascanu <r.pascanu@gmail>"
import copy import copy
...@@ -21,9 +21,14 @@ import numpy as np ...@@ -21,9 +21,14 @@ import numpy as np
from theano import scalar as ts from theano import scalar as ts
from theano import tensor as tt from theano import tensor as tt
from theano.compile.function.pfunc import rebuild_collect_shared
from theano.configdefaults import config from theano.configdefaults import config
from theano.graph.basic import Constant, Variable, equal_computations, graph_inputs from theano.graph.basic import (
Constant,
Variable,
clone_replace,
equal_computations,
graph_inputs,
)
from theano.graph.fg import FunctionGraph from theano.graph.fg import FunctionGraph
from theano.graph.op import get_test_value from theano.graph.op import get_test_value
from theano.graph.opt import TopoOptimizer, local_optimizer from theano.graph.opt import TopoOptimizer, local_optimizer
...@@ -180,54 +185,6 @@ def hash_listsDictsTuples(x): ...@@ -180,54 +185,6 @@ def hash_listsDictsTuples(x):
return hash_value return hash_value
def clone(output, replace=None, strict=True, share_inputs=True):
"""
Function that allows replacing subgraphs of a computational graph.
It returns a copy of the initial subgraph with the corresponding
substitutions.
Parameters
----------
output : Theano Variables (or Theano expressions)
Theano expression that represents the computational graph.
replace : dict
Dictionary describing which subgraphs should be replaced by what.
share_inputs : bool
If True, use the same inputs (and shared variables) as the original
graph. If False, clone them. Note that cloned shared variables still
use the same underlying storage, so they will always have the same
value.
"""
if isinstance(replace, dict):
items = list(replace.items())
elif isinstance(replace, (list, tuple)):
items = replace
elif replace is None:
items = []
else:
raise ValueError(
(
"replace is neither a dictionary, list, "
f"tuple or None ! The value provided is {replace},"
f"of type {type(replace)}"
)
)
tmp_replace = [(x, x.type()) for x, y in items]
new_replace = [(x, y) for ((_, x), (_, y)) in zip(tmp_replace, items)]
_, _outs, _ = rebuild_collect_shared(
output, [], tmp_replace, [], strict, share_inputs
)
# TODO Explain why we call it twice ?!
_, outs, _ = rebuild_collect_shared(
_outs, [], new_replace, [], strict, share_inputs
)
return outs
def map_variables(replacer, graphs, additional_inputs=None): def map_variables(replacer, graphs, additional_inputs=None):
"""Construct new graphs based on 'graphs' with some variables replaced """Construct new graphs based on 'graphs' with some variables replaced
according to 'replacer'. according to 'replacer'.
...@@ -285,7 +242,7 @@ def map_variables(replacer, graphs, additional_inputs=None): ...@@ -285,7 +242,7 @@ def map_variables(replacer, graphs, additional_inputs=None):
for input_, new_input in zip(inputs_, new_inputs) for input_, new_input in zip(inputs_, new_inputs)
if new_input is not input_ if new_input is not input_
] ]
graphs = clone(graphs, share_inputs=True, replace=replacements) graphs = clone_replace(graphs, share_inputs=True, replace=replacements)
inputs_ = list(set(list(graph_inputs(graphs)) + list(additional_inputs))) inputs_ = list(set(list(graph_inputs(graphs)) + list(additional_inputs)))
fg = FunctionGraph(inputs_, graphs, clone=False) fg = FunctionGraph(inputs_, graphs, clone=False)
...@@ -426,7 +383,9 @@ def _map_variables_inner( ...@@ -426,7 +383,9 @@ def _map_variables_inner(
replacements.extend(outer_to_inner.items()) replacements.extend(outer_to_inner.items())
(new_graph,) = clone([new_graph], share_inputs=True, replace=replacements) (new_graph,) = clone_replace(
[new_graph], share_inputs=True, replace=replacements
)
return new_graph return new_graph
new_inner_outputs = map_variables(inner_replacer, inner_outputs) new_inner_outputs = map_variables(inner_replacer, inner_outputs)
...@@ -908,7 +867,7 @@ def reconstruct_graph(inputs, outputs, tag=None): ...@@ -908,7 +867,7 @@ def reconstruct_graph(inputs, outputs, tag=None):
if isinstance(inp, Constant): if isinstance(inp, Constant):
givens[inp] = inp.clone() givens[inp] = inp.clone()
nw_outputs = clone(outputs, replace=givens) nw_outputs = clone_replace(outputs, replace=givens)
return (nw_inputs, nw_outputs) return (nw_inputs, nw_outputs)
...@@ -1187,4 +1146,4 @@ def forced_replace(out, x, y): ...@@ -1187,4 +1146,4 @@ def forced_replace(out, x, y):
if len(to_replace) == 0: if len(to_replace) == 0:
return out return out
return clone(out, replace=to_replace) return clone_replace(out, replace=to_replace)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论