提交 a788e179 authored 作者: lamblin's avatar lamblin

Merge pull request #1238 from pascanur/fix_clone

fix to clone functionality of theano
...@@ -169,14 +169,33 @@ def clone(output, ...@@ -169,14 +169,33 @@ def clone(output,
shared variables still use the same underlying storage, so they shared variables still use the same underlying storage, so they
will always have the same value. will always have the same value.
""" """
if isinstance(replace, dict):
items = replace.items()
elif isinstance(replace, (list, tuple)):
items = replace
elif replace is None:
items = []
else:
raise ValueError(("replace is neither a dictionary, list, "
"tuple or None ! The value provided is %s,"
"of type %s")%(str(replace), str(type(replace))))
tmp_replace = [(x, x.type()) for x, y in items]
new_replace = [(x, y) for ((_, x), (_, y)) in zip(tmp_replace,
items)]
_, _outs, _ = rebuild_collect_shared(output,
[],
tmp_replace,
[],
strict,
copy_inputs)
_, outs, _ = rebuild_collect_shared(_outs,
[],
new_replace,
[],
strict,
copy_inputs)
inps, outs, other_stuff = rebuild_collect_shared(output,
[],
replace,
[],
strict,
copy_inputs
)
return outs return outs
......
...@@ -3427,6 +3427,23 @@ class T_Scan(unittest.TestCase): ...@@ -3427,6 +3427,23 @@ class T_Scan(unittest.TestCase):
assert numpy.allclose(outs[2], v_w + 3) assert numpy.allclose(outs[2], v_w + 3)
assert numpy.allclose(sh.get_value(), v_w + 4) assert numpy.allclose(sh.get_value(), v_w + 4)
def test_clone(self):
def test(x, y, mention_y):
if mention_y:
d = 0.1 + 0 * y
else:
d = 0.1
out = theano.clone(y, replace={x:x + d})
theano.printing.debugprint(out)
return theano.function([], out)()
x = theano.shared(numpy.asarray(0., dtype=theano.config.floatX))
assert numpy.allclose(test(x, tensor.sum((x+1)**2), mention_y=False),
1.21000003815)
assert numpy.allclose(test(x, tensor.sum((x+1)**2), mention_y=True),
1.21000003815)
def test_speed(): def test_speed():
# #
# This function prints out the speed of very simple recurrent # This function prints out the speed of very simple recurrent
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论