提交 a788e179 authored 作者: lamblin's avatar lamblin

Merge pull request #1238 from pascanur/fix_clone

fix to clone functionality of theano
......@@ -169,14 +169,33 @@ def clone(output,
shared variables still use the same underlying storage, so they
will always have the same value.
"""
if isinstance(replace, dict):
items = replace.items()
elif isinstance(replace, (list, tuple)):
items = replace
elif replace is None:
items = []
else:
raise ValueError(("replace is neither a dictionary, list, "
"tuple or None ! The value provided is %s,"
"of type %s")%(str(replace), str(type(replace))))
tmp_replace = [(x, x.type()) for x, y in items]
new_replace = [(x, y) for ((_, x), (_, y)) in zip(tmp_replace,
items)]
_, _outs, _ = rebuild_collect_shared(output,
[],
tmp_replace,
[],
strict,
copy_inputs)
_, outs, _ = rebuild_collect_shared(_outs,
[],
new_replace,
[],
strict,
copy_inputs)
inps, outs, other_stuff = rebuild_collect_shared(output,
[],
replace,
[],
strict,
copy_inputs
)
return outs
......
......@@ -3427,6 +3427,23 @@ class T_Scan(unittest.TestCase):
assert numpy.allclose(outs[2], v_w + 3)
assert numpy.allclose(sh.get_value(), v_w + 4)
def test_clone(self):
def test(x, y, mention_y):
if mention_y:
d = 0.1 + 0 * y
else:
d = 0.1
out = theano.clone(y, replace={x:x + d})
theano.printing.debugprint(out)
return theano.function([], out)()
x = theano.shared(numpy.asarray(0., dtype=theano.config.floatX))
assert numpy.allclose(test(x, tensor.sum((x+1)**2), mention_y=False),
1.21000003815)
assert numpy.allclose(test(x, tensor.sum((x+1)**2), mention_y=True),
1.21000003815)
def test_speed():
#
# This function prints out the speed of very simple recurrent
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论