提交 f03830f8 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

add tests for ticket 766 and bug reported by Ilya

上级 e878960d
......@@ -2341,6 +2341,68 @@ class T_Scan(unittest.TestCase):
assert len(lssc) == 2
def test_bug_ticket766(self):
coefficients = theano.tensor.vector("coefficients")
x = tensor.scalar("x"); max_coefficients_supported = 10000
# Generate the components of the polynomial
full_range=theano.tensor.arange(max_coefficients_supported)
components, updates = theano.scan(fn=lambda coeff, power, free_var:
coeff * (free_var ** power),
sequences=[coefficients, full_range],
non_sequences=x)
polynomial1 = components.sum()
polynomial2, updates = theano.scan(fn=lambda coeff, power, prev, free_var:
prev + coeff * (free_var ** power),
outputs_info=theano.tensor.constant(0, dtype='floatX'),
sequences=[coefficients, full_range],
non_sequences=x)
polynomial3, updates = theano.scan(fn=lambda coeff, power, prev, free_var:
prev + coeff * (free_var ** power),
outputs_info=0.,
sequences=[coefficients, full_range],
non_sequences=x)
calculate_polynomial = theano.function(inputs=[coefficients, x],
outputs=[polynomial1, polynomial2[-1]])
test_coeff = numpy.asarray([1, 0, 2], dtype=numpy.float32)
# This will be tested by DEBUG_MODE
calculate_polynomial(test_coeff, 3)
# 19.0
def test_bugFunctioProvidesIntermediateNodesAsInputs(self):
# This is a bug recently reported by Ilya
# made it CPU friendly
V = tensor.ftensor3('INPUT')
orig = tensor.fmatrix('PARAM')
# = gpu_from_host(orig) # <-- this doesn't work
W = orig + 2 # <-- has same effect but it works on CPU as well
#W = T.fmatrix('PARAM') # <-- this line works
def one_step(v,W):
o = v + 1 + W.sum() # <-- this doesn't work
#o = v + 1 # <-- this line works
return o
OS, updates = theano.scan(
fn = one_step,
sequences = V,
outputs_info = [None],
non_sequences = [W]
)
O = OS.sum() + W.sum()
# This bug manifests itself by not allowing the function to compile,
# so if it compiles it means the test pass
f = theano.function([V, W], O)
def test_while2(self):
x = tensor.vector('x')
def lambda_fn(x_t):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论