提交 6b72779b authored 作者: Razvan Pascanu's avatar Razvan Pascanu

mixt commit for improving tests

上级 babdf02d
......@@ -215,7 +215,7 @@ def canonical_arguments(sequences,
orig_input = orig_input[::-1]
if n_steps is not None:
orig_input = tensor.switch(negative_n_steps, orig_input[::-1],
org_input)
orig_input)
for k in input['taps']:
# We cut the sequence such that seq[i] to correspond to
# seq[i-k]
......
import cPickle
import numpy
import unittest
import theano
from theano.compile.pfunc import rebuild_collect_shared
import theano.sandbox.scan_module as scan_module
if theano.config.mode == 'FAST_COMPILE':
mode_with_opt = theano.compile.mode.get_mode('FAST_RUN')
......@@ -160,3 +162,120 @@ def grab_scan_node(output):
return None
else:
return rval
class TestScanUtils(unittest.TestCase):
def test_cloning_no_replace_strict_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
x = theano.tensor.vector('x')
y = theano.tensor.vector('y')
z = theano.shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = scan_module.scan_utils.clone(f1,
replace=None,
strict=True,
copy_inputs=True)
f2_inp = theano.gof.graph.inputs([f2])
assert z in f2_inp
assert x in f2_inp
assert y in f2_inp
def test_cloning_no_replace_strict_not_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
x = theano.tensor.vector('x')
y = theano.tensor.vector('y')
z = theano.shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = scan_module.scan_utils.clone(f1,
replace=None,
strict=True,
copy_inputs=False)
f2_inp = theano.gof.graph.inputs([f2])
assert not z in f2_inp
assert not x in f2_inp
assert not y in f2_inp
def test_cloning_replace_strict_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
x = theano.tensor.vector('x')
y = theano.tensor.vector('y')
y2 = theano.tensor.vector('y2')
z = theano.shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = scan_module.scan_utils.clone(f1,
replace={y: y2},
strict=True,
copy_inputs=True)
f2_inp = theano.gof.graph.inputs([f2])
assert z in f2_inp
assert x in f2_inp
assert y2 in f2_inp
def test_cloning_replace_not_strict_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
x = theano.tensor.vector('x')
y = theano.tensor.fvector('y')
y2 = theano.tensor.dvector('y2')
z = theano.shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = scan_module.scan_utils.clone(f1,
replace={y: y2},
strict=False,
copy_inputs=True)
f2_inp = theano.gof.graph.inputs([f2])
assert z in f2_inp
assert x in f2_inp
assert y2 in f2_inp
def test_cloning_replace_strict_not_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
x = theano.tensor.vector('x')
y = theano.tensor.vector('y')
y2 = theano.tensor.vector('y2')
z = theano.shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = scan_module.scan_utils.clone(f1,
replace={y: y2},
strict=True,
copy_inputs=False)
f2_inp = theano.gof.graph.inputs([f2])
assert not z in f2_inp
assert not x in f2_inp
assert not y2 in f2_inp
def test_cloning_replace_not_strict_not_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
x = theano.tensor.vector('x')
y = theano.tensor.fvector('y')
y2 = theano.tensor.dvector('y2')
z = theano.shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = scan_module.scan_utils.clone(f1,
replace={y: y2},
strict=False,
copy_inputs=False)
f2_inp = theano.gof.graph.inputs([f2])
assert not z in f2_inp
assert not x in f2_inp
assert not y2 in f2_inp
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论