提交 6101a454 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #1834 from delallea/minor

Renamed "copy_inputs" into "share_inputs" in clone
...@@ -15,6 +15,7 @@ __contact__ = "Razvan Pascanu <r.pascanu@gmail>" ...@@ -15,6 +15,7 @@ __contact__ = "Razvan Pascanu <r.pascanu@gmail>"
import copy import copy
import logging import logging
import warnings
from itertools import izip from itertools import izip
import numpy import numpy
...@@ -146,7 +147,8 @@ def get_updates_and_outputs(ls): ...@@ -146,7 +147,8 @@ def get_updates_and_outputs(ls):
raise ValueError(error_msg) raise ValueError(error_msg)
def clone(output, replace=None, strict=True, copy_inputs=True): DEPRECATED_ARG = object()
def clone(output, replace=None, strict=True, share_inputs=True, copy_inputs=DEPRECATED_ARG):
""" """
Function that allows replacing subgraphs of a computational Function that allows replacing subgraphs of a computational
graph. It returns a copy of the initial subgraph with the corresponding graph. It returns a copy of the initial subgraph with the corresponding
...@@ -159,14 +161,24 @@ def clone(output, replace=None, strict=True, copy_inputs=True): ...@@ -159,14 +161,24 @@ def clone(output, replace=None, strict=True, copy_inputs=True):
:type replace: dict :type replace: dict
:param replace: dictionary describing which subgraphs should be :param replace: dictionary describing which subgraphs should be
replaced by what replaced by what
:type share_inputs: bool
:param share_inputs: If True, use the same inputs (and shared variables)
as the original graph. If False, clone them. Note that cloned
shared variables still use the same underlying storage, so they
will always have the same value.
""" """
if copy_inputs is not DEPRECATED_ARG:
warnings.warn('In `clone()` function, the argument `copy_inputs` has been deprecated and renamed into `share_inputs`')
assert share_inputs # since we used `copy_inputs` we should have default value for `share_inputs`
share_inputs = copy_inputs
inps, outs, other_stuff = rebuild_collect_shared(output, inps, outs, other_stuff = rebuild_collect_shared(output,
[], [],
replace, replace,
[], [],
strict, strict,
copy_inputs) share_inputs)
return outs return outs
......
...@@ -178,7 +178,7 @@ class TestScanUtils(unittest.TestCase): ...@@ -178,7 +178,7 @@ class TestScanUtils(unittest.TestCase):
f2 = scan_module.scan_utils.clone(f1, f2 = scan_module.scan_utils.clone(f1,
replace=None, replace=None,
strict=True, strict=True,
copy_inputs=True) share_inputs=True)
f2_inp = theano.gof.graph.inputs([f2]) f2_inp = theano.gof.graph.inputs([f2])
assert z in f2_inp assert z in f2_inp
...@@ -197,7 +197,7 @@ class TestScanUtils(unittest.TestCase): ...@@ -197,7 +197,7 @@ class TestScanUtils(unittest.TestCase):
f2 = scan_module.scan_utils.clone(f1, f2 = scan_module.scan_utils.clone(f1,
replace=None, replace=None,
strict=True, strict=True,
copy_inputs=False) share_inputs=False)
f2_inp = theano.gof.graph.inputs([f2]) f2_inp = theano.gof.graph.inputs([f2])
assert not z in f2_inp assert not z in f2_inp
...@@ -217,7 +217,7 @@ class TestScanUtils(unittest.TestCase): ...@@ -217,7 +217,7 @@ class TestScanUtils(unittest.TestCase):
f2 = scan_module.scan_utils.clone(f1, f2 = scan_module.scan_utils.clone(f1,
replace={y: y2}, replace={y: y2},
strict=True, strict=True,
copy_inputs=True) share_inputs=True)
f2_inp = theano.gof.graph.inputs([f2]) f2_inp = theano.gof.graph.inputs([f2])
assert z in f2_inp assert z in f2_inp
assert x in f2_inp assert x in f2_inp
...@@ -236,7 +236,7 @@ class TestScanUtils(unittest.TestCase): ...@@ -236,7 +236,7 @@ class TestScanUtils(unittest.TestCase):
f2 = scan_module.scan_utils.clone(f1, f2 = scan_module.scan_utils.clone(f1,
replace={y: y2}, replace={y: y2},
strict=False, strict=False,
copy_inputs=True) share_inputs=True)
f2_inp = theano.gof.graph.inputs([f2]) f2_inp = theano.gof.graph.inputs([f2])
assert z in f2_inp assert z in f2_inp
assert x in f2_inp assert x in f2_inp
...@@ -255,7 +255,7 @@ class TestScanUtils(unittest.TestCase): ...@@ -255,7 +255,7 @@ class TestScanUtils(unittest.TestCase):
f2 = scan_module.scan_utils.clone(f1, f2 = scan_module.scan_utils.clone(f1,
replace={y: y2}, replace={y: y2},
strict=True, strict=True,
copy_inputs=False) share_inputs=False)
f2_inp = theano.gof.graph.inputs([f2]) f2_inp = theano.gof.graph.inputs([f2])
assert not z in f2_inp assert not z in f2_inp
assert not x in f2_inp assert not x in f2_inp
...@@ -274,7 +274,7 @@ class TestScanUtils(unittest.TestCase): ...@@ -274,7 +274,7 @@ class TestScanUtils(unittest.TestCase):
f2 = scan_module.scan_utils.clone(f1, f2 = scan_module.scan_utils.clone(f1,
replace={y: y2}, replace={y: y2},
strict=False, strict=False,
copy_inputs=False) share_inputs=False)
f2_inp = theano.gof.graph.inputs([f2]) f2_inp = theano.gof.graph.inputs([f2])
assert not z in f2_inp assert not z in f2_inp
assert not x in f2_inp assert not x in f2_inp
......
...@@ -15,10 +15,10 @@ __contact__ = "Razvan Pascanu <r.pascanu@gmail>" ...@@ -15,10 +15,10 @@ __contact__ = "Razvan Pascanu <r.pascanu@gmail>"
import copy import copy
import logging import logging
import warnings
from itertools import izip from itertools import izip
import numpy import numpy
import warnings
import theano import theano
from theano.compile.pfunc import rebuild_collect_shared from theano.compile.pfunc import rebuild_collect_shared
...@@ -157,10 +157,12 @@ def hash_listsDictsTuples(x): ...@@ -157,10 +157,12 @@ def hash_listsDictsTuples(x):
return hash_value return hash_value
DEPRECATED_ARG = object()
def clone(output, def clone(output,
replace=None, replace=None,
strict=True, strict=True,
copy_inputs=True): share_inputs=True,
copy_inputs=DEPRECATED_ARG):
""" """
Function that allows replacing subgraphs of a computational Function that allows replacing subgraphs of a computational
graph. It returns a copy of the initial subgraph with the corresponding graph. It returns a copy of the initial subgraph with the corresponding
...@@ -174,12 +176,17 @@ def clone(output, ...@@ -174,12 +176,17 @@ def clone(output,
:param replace: dictionary describing which subgraphs should be :param replace: dictionary describing which subgraphs should be
replaced by what replaced by what
:type copy_inputs: bool :type share_inputs: bool
:param copy_inputs: If True, use the same inputs (and shared variables) :param share_inputs: If True, use the same inputs (and shared variables)
as the original graph. If False, clone them. Note that cloned as the original graph. If False, clone them. Note that cloned
shared variables still use the same underlying storage, so they shared variables still use the same underlying storage, so they
will always have the same value. will always have the same value.
""" """
if copy_inputs is not DEPRECATED_ARG:
warnings.warn('In `clone()` function, the argument `copy_inputs` has been deprecated and renamed into `share_inputs`')
assert share_inputs # since we used `copy_inputs` we should have default value for `share_inputs`
share_inputs = copy_inputs
if isinstance(replace, dict): if isinstance(replace, dict):
items = replace.items() items = replace.items()
elif isinstance(replace, (list, tuple)): elif isinstance(replace, (list, tuple)):
...@@ -198,14 +205,15 @@ def clone(output, ...@@ -198,14 +205,15 @@ def clone(output,
tmp_replace, tmp_replace,
[], [],
strict, strict,
copy_inputs) share_inputs)
# TODO Explain why we call it twice ?!
_, outs, _ = rebuild_collect_shared(_outs, _, outs, _ = rebuild_collect_shared(_outs,
[], [],
new_replace, new_replace,
[], [],
strict, strict,
copy_inputs) share_inputs)
return outs return outs
......
...@@ -1964,7 +1964,7 @@ class T_Scan(unittest.TestCase): ...@@ -1964,7 +1964,7 @@ class T_Scan(unittest.TestCase):
f2 = theano.clone(f1, f2 = theano.clone(f1,
replace=None, replace=None,
strict=True, strict=True,
copy_inputs=True) share_inputs=True)
f2_inp = theano.gof.graph.inputs([f2]) f2_inp = theano.gof.graph.inputs([f2])
assert z in f2_inp assert z in f2_inp
...@@ -1983,7 +1983,7 @@ class T_Scan(unittest.TestCase): ...@@ -1983,7 +1983,7 @@ class T_Scan(unittest.TestCase):
f2 = theano.clone(f1, f2 = theano.clone(f1,
replace=None, replace=None,
strict=True, strict=True,
copy_inputs=False) share_inputs=False)
f2_inp = theano.gof.graph.inputs([f2]) f2_inp = theano.gof.graph.inputs([f2])
assert not z in f2_inp assert not z in f2_inp
...@@ -2003,7 +2003,7 @@ class T_Scan(unittest.TestCase): ...@@ -2003,7 +2003,7 @@ class T_Scan(unittest.TestCase):
f2 = theano.clone(f1, f2 = theano.clone(f1,
replace=OrderedDict([(y, y2)]), replace=OrderedDict([(y, y2)]),
strict=True, strict=True,
copy_inputs=True) share_inputs=True)
f2_inp = theano.gof.graph.inputs([f2]) f2_inp = theano.gof.graph.inputs([f2])
assert z in f2_inp assert z in f2_inp
assert x in f2_inp assert x in f2_inp
...@@ -2022,7 +2022,7 @@ class T_Scan(unittest.TestCase): ...@@ -2022,7 +2022,7 @@ class T_Scan(unittest.TestCase):
f2 = theano.clone(f1, f2 = theano.clone(f1,
replace=OrderedDict([(y, y2)]), replace=OrderedDict([(y, y2)]),
strict=False, strict=False,
copy_inputs=True) share_inputs=True)
f2_inp = theano.gof.graph.inputs([f2]) f2_inp = theano.gof.graph.inputs([f2])
assert z in f2_inp assert z in f2_inp
assert x in f2_inp assert x in f2_inp
...@@ -2041,7 +2041,7 @@ class T_Scan(unittest.TestCase): ...@@ -2041,7 +2041,7 @@ class T_Scan(unittest.TestCase):
f2 = theano.clone(f1, f2 = theano.clone(f1,
replace=[(y, y2)], replace=[(y, y2)],
strict=True, strict=True,
copy_inputs=False) share_inputs=False)
f2_inp = theano.gof.graph.inputs([f2]) f2_inp = theano.gof.graph.inputs([f2])
assert not z in f2_inp assert not z in f2_inp
assert not x in f2_inp assert not x in f2_inp
...@@ -2060,7 +2060,7 @@ class T_Scan(unittest.TestCase): ...@@ -2060,7 +2060,7 @@ class T_Scan(unittest.TestCase):
f2 = theano.clone(f1, f2 = theano.clone(f1,
replace=[(y, y2)], replace=[(y, y2)],
strict=False, strict=False,
copy_inputs=False) share_inputs=False)
f2_inp = theano.gof.graph.inputs([f2]) f2_inp = theano.gof.graph.inputs([f2])
assert not z in f2_inp assert not z in f2_inp
assert not x in f2_inp assert not x in f2_inp
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论