提交 ded724a9 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

fix to clone functionality of theano

上级 116ad0f3
...@@ -169,14 +169,24 @@ def clone(output, ...@@ -169,14 +169,24 @@ def clone(output,
shared variables still use the same underlying storage, so they shared variables still use the same underlying storage, so they
will always have the same value. will always have the same value.
""" """
items = replace.items()
tmp_replace = [(x, x.type()) for x, y in items]
new_replace = [(x, y) for ((_, x), (_, y)) in zip(tmp_replace,
items)]
_, _outs, _ = rebuild_collect_shared(output,
[],
tmp_replace,
[],
strict,
copy_inputs)
_, outs, _ = rebuild_collect_shared(_outs,
[],
new_replace,
[],
strict,
copy_inputs)
inps, outs, other_stuff = rebuild_collect_shared(output,
[],
replace,
[],
strict,
copy_inputs
)
return outs return outs
......
...@@ -3427,6 +3427,23 @@ class T_Scan(unittest.TestCase): ...@@ -3427,6 +3427,23 @@ class T_Scan(unittest.TestCase):
assert numpy.allclose(outs[2], v_w + 3) assert numpy.allclose(outs[2], v_w + 3)
assert numpy.allclose(sh.get_value(), v_w + 4) assert numpy.allclose(sh.get_value(), v_w + 4)
def test_clone(self):
def test(x, y, mention_y):
if mention_y:
d = 0.1 + 0 * y
else:
d = 0.1
out = theano.clone(y, replace={x:x + d})
theano.printing.debugprint(out)
return theano.function([], out)()
x = theano.shared(numpy.asarray(0., dtype=theano.config.floatX))
assert numpy.allclose(test(x, tensor.sum((x+1)**2), mention_y=False),
1.21000003815)
assert numpy.allclose(test(x, tensor.sum((x+1)**2), mention_y=True),
1.21000003815)
def test_speed(): def test_speed():
# #
# This function prints out the speed of very simple recurrent # This function prints out the speed of very simple recurrent
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论