提交 6101a454 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #1834 from delallea/minor

Renamed "copy_inputs" into "share_inputs" in clone
......@@ -15,6 +15,7 @@ __contact__ = "Razvan Pascanu <r.pascanu@gmail>"
import copy
import logging
import warnings
from itertools import izip
import numpy
......@@ -146,7 +147,8 @@ def get_updates_and_outputs(ls):
raise ValueError(error_msg)
def clone(output, replace=None, strict=True, copy_inputs=True):
DEPRECATED_ARG = object()
def clone(output, replace=None, strict=True, share_inputs=True, copy_inputs=DEPRECATED_ARG):
"""
Function that allows replacing subgraphs of a computational
graph. It returns a copy of the initial subgraph with the corresponding
......@@ -159,14 +161,24 @@ def clone(output, replace=None, strict=True, copy_inputs=True):
:type replace: dict
:param replace: dictionary describing which subgraphs should be
replaced by what
:type share_inputs: bool
:param share_inputs: If True, use the same inputs (and shared variables)
as the original graph. If False, clone them. Note that cloned
shared variables still use the same underlying storage, so they
will always have the same value.
"""
if copy_inputs is not DEPRECATED_ARG:
warnings.warn('In `clone()` function, the argument `copy_inputs` has been deprecated and renamed into `share_inputs`')
assert share_inputs # since we used `copy_inputs` we should have default value for `share_inputs`
share_inputs = copy_inputs
inps, outs, other_stuff = rebuild_collect_shared(output,
[],
replace,
[],
strict,
copy_inputs)
share_inputs)
return outs
......
......@@ -178,7 +178,7 @@ class TestScanUtils(unittest.TestCase):
f2 = scan_module.scan_utils.clone(f1,
replace=None,
strict=True,
copy_inputs=True)
share_inputs=True)
f2_inp = theano.gof.graph.inputs([f2])
assert z in f2_inp
......@@ -197,7 +197,7 @@ class TestScanUtils(unittest.TestCase):
f2 = scan_module.scan_utils.clone(f1,
replace=None,
strict=True,
copy_inputs=False)
share_inputs=False)
f2_inp = theano.gof.graph.inputs([f2])
assert not z in f2_inp
......@@ -217,7 +217,7 @@ class TestScanUtils(unittest.TestCase):
f2 = scan_module.scan_utils.clone(f1,
replace={y: y2},
strict=True,
copy_inputs=True)
share_inputs=True)
f2_inp = theano.gof.graph.inputs([f2])
assert z in f2_inp
assert x in f2_inp
......@@ -236,7 +236,7 @@ class TestScanUtils(unittest.TestCase):
f2 = scan_module.scan_utils.clone(f1,
replace={y: y2},
strict=False,
copy_inputs=True)
share_inputs=True)
f2_inp = theano.gof.graph.inputs([f2])
assert z in f2_inp
assert x in f2_inp
......@@ -255,7 +255,7 @@ class TestScanUtils(unittest.TestCase):
f2 = scan_module.scan_utils.clone(f1,
replace={y: y2},
strict=True,
copy_inputs=False)
share_inputs=False)
f2_inp = theano.gof.graph.inputs([f2])
assert not z in f2_inp
assert not x in f2_inp
......@@ -274,7 +274,7 @@ class TestScanUtils(unittest.TestCase):
f2 = scan_module.scan_utils.clone(f1,
replace={y: y2},
strict=False,
copy_inputs=False)
share_inputs=False)
f2_inp = theano.gof.graph.inputs([f2])
assert not z in f2_inp
assert not x in f2_inp
......
......@@ -15,10 +15,10 @@ __contact__ = "Razvan Pascanu <r.pascanu@gmail>"
import copy
import logging
import warnings
from itertools import izip
import numpy
import warnings
import theano
from theano.compile.pfunc import rebuild_collect_shared
......@@ -157,10 +157,12 @@ def hash_listsDictsTuples(x):
return hash_value
DEPRECATED_ARG = object()
def clone(output,
replace=None,
strict=True,
copy_inputs=True):
share_inputs=True,
copy_inputs=DEPRECATED_ARG):
"""
Function that allows replacing subgraphs of a computational
graph. It returns a copy of the initial subgraph with the corresponding
......@@ -174,12 +176,17 @@ def clone(output,
:param replace: dictionary describing which subgraphs should be
replaced by what
:type copy_inputs: bool
:param copy_inputs: If True, use the same inputs (and shared variables)
:type share_inputs: bool
:param share_inputs: If True, use the same inputs (and shared variables)
as the original graph. If False, clone them. Note that cloned
shared variables still use the same underlying storage, so they
will always have the same value.
"""
if copy_inputs is not DEPRECATED_ARG:
warnings.warn('In `clone()` function, the argument `copy_inputs` has been deprecated and renamed into `share_inputs`')
assert share_inputs # since we used `copy_inputs` we should have default value for `share_inputs`
share_inputs = copy_inputs
if isinstance(replace, dict):
items = replace.items()
elif isinstance(replace, (list, tuple)):
......@@ -198,14 +205,15 @@ def clone(output,
tmp_replace,
[],
strict,
copy_inputs)
share_inputs)
# TODO Explain why we call it twice ?!
_, outs, _ = rebuild_collect_shared(_outs,
[],
new_replace,
[],
strict,
copy_inputs)
share_inputs)
return outs
......
......@@ -1964,7 +1964,7 @@ class T_Scan(unittest.TestCase):
f2 = theano.clone(f1,
replace=None,
strict=True,
copy_inputs=True)
share_inputs=True)
f2_inp = theano.gof.graph.inputs([f2])
assert z in f2_inp
......@@ -1983,7 +1983,7 @@ class T_Scan(unittest.TestCase):
f2 = theano.clone(f1,
replace=None,
strict=True,
copy_inputs=False)
share_inputs=False)
f2_inp = theano.gof.graph.inputs([f2])
assert not z in f2_inp
......@@ -2003,7 +2003,7 @@ class T_Scan(unittest.TestCase):
f2 = theano.clone(f1,
replace=OrderedDict([(y, y2)]),
strict=True,
copy_inputs=True)
share_inputs=True)
f2_inp = theano.gof.graph.inputs([f2])
assert z in f2_inp
assert x in f2_inp
......@@ -2022,7 +2022,7 @@ class T_Scan(unittest.TestCase):
f2 = theano.clone(f1,
replace=OrderedDict([(y, y2)]),
strict=False,
copy_inputs=True)
share_inputs=True)
f2_inp = theano.gof.graph.inputs([f2])
assert z in f2_inp
assert x in f2_inp
......@@ -2041,7 +2041,7 @@ class T_Scan(unittest.TestCase):
f2 = theano.clone(f1,
replace=[(y, y2)],
strict=True,
copy_inputs=False)
share_inputs=False)
f2_inp = theano.gof.graph.inputs([f2])
assert not z in f2_inp
assert not x in f2_inp
......@@ -2060,7 +2060,7 @@ class T_Scan(unittest.TestCase):
f2 = theano.clone(f1,
replace=[(y, y2)],
strict=False,
copy_inputs=False)
share_inputs=False)
f2_inp = theano.gof.graph.inputs([f2])
assert not z in f2_inp
assert not x in f2_inp
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论