提交 b7294365 authored 作者: lamblin's avatar lamblin

Merge pull request #1376 from nouiz/fix_tests

Fix tests
......@@ -2,6 +2,7 @@ from copy import deepcopy
import numpy
import theano
from theano.gof.op import PureOp
from theano.gof import Apply, generic, Container
from theano.gof.link import LocalLinker, map_storage, add_clear_storage
......@@ -13,15 +14,16 @@ import theano.tensor as T
class IfElseIfElseIf(PureOp):
def __init__(self, inplace=False):
self.inplace=inplace # check destroyhandler and others to ensure that a view_map with
# check destroyhandler and others to ensure that a view_map with
self.inplace = inplace
#multiple inputs can work
assert not self.inplace
def make_node(self, c1, t1, c2,t2,c3,t3,f3):
def make_node(self, c1, t1, c2, t2, c3, t3, f3):
assert t1.type == f3.type
assert t2.type == t3.type
assert t3.type == f3.type
return Apply(self, [c1,t1,c2,t2,c3,t3,f3], [t1.type()])
return Apply(self, [c1, t1, c2, t2, c3, t3, f3], [t1.type()])
def make_thunk(self, node, storage_map, compute_map, no_recycling):
......@@ -41,8 +43,9 @@ class IfElseIfElseIf(PureOp):
if not input_computed[1][0]:
return [1]
else:
output_computed[0][0]=1
output_registers[0][0]=outtype.filter(deepcopy(input_registers[1][0]))
output_computed[0][0] = 1
output_registers[0][0] = outtype.filter(
deepcopy(input_registers[1][0]))
return []
else:
if not input_computed[2][0]:
......@@ -54,7 +57,8 @@ class IfElseIfElseIf(PureOp):
return [3]
else:
output_computed[0][0] = 1
output_registers[0][0] = outtype.filter(deepcopy(input_registers[3][0]))
output_registers[0][0] = outtype.filter(
deepcopy(input_registers[3][0]))
return []
else:
if not input_computed[4][0]:
......@@ -66,30 +70,33 @@ class IfElseIfElseIf(PureOp):
return [5]
else:
output_computed[0][0] = 1
output_registers[0][0] = outtype.filter(deepcopy(input_registers[5][0]))
output_registers[0][0] = outtype.filter(
deepcopy(input_registers[5][0]))
return []
else:
if not input_computed[6][0]:
return [6]
else:
output_computed[0][0] = 1
output_registers[0][0] = outtype.filter(deepcopy(input_registers[6][0]))
output_registers[0][0] = outtype.filter(
deepcopy(input_registers[6][0]))
return []
thunk.lazy = True
return thunk
class NotImplementedOp(PureOp):
class E(Exception): pass
class E(Exception):
pass
def make_node(self, x):
return Apply(self, [x], [x.type()])
def make_thunk(self, node, storage_map, compute_map, no_recycling):
def thunk():
raise self.E()
thunk.lazy=False
thunk.lazy = False
return thunk
......@@ -99,22 +106,28 @@ def test_ifelse():
c = generic()
notimpl = NotImplementedOp()
f = function([a,b,c], ifelse(a, notimpl(b), c),
mode=Mode(linker='vm', optimizer='fast_run'))
try:
print "case 1"
f( 1, 'a', 'b')
assert False
except NotImplementedOp.E:
pass
print "... passed"
print "case 2"
print f( 0, 'a', 'b')
assert f( 0, 'a', 'b') == 'b'
print "... passed"
lazys = [True]
# We need lazy to end up being True for this test.
if theano.config.vm.lazy in [True, None]:
lazys = [True, None]
for cloop in [True, False]:
for lazy in lazys:
linker = theano.gof.vm.VM_Linker(use_cloop=cloop, lazy=lazy)
f = function([a, b, c], ifelse(a, notimpl(b), c),
mode=Mode(linker=linker, optimizer='fast_run'))
try:
print "case 1"
f(1, 'a', 'b')
assert False
except NotImplementedOp.E:
pass
print "... passed"
print "case 2"
print f(0, 'a', 'b')
assert f(0, 'a', 'b') == 'b'
print "... passed"
def more_complex_test():
......@@ -125,19 +138,26 @@ def more_complex_test():
x2 = T.scalar('x2')
c1 = T.scalar('c1')
c2 = T.scalar('c2')
t1 = ifelse(c1,x1,notimpl(x2))
t1 = ifelse(c1, x1, notimpl(x2))
t1.name = 't1'
t2 = t1*10
t2 = t1 * 10
t2.name = 't2'
t3 = ifelse(c2,t2, x1+t1)
t3 = ifelse(c2, t2, x1 + t1)
t3.name = 't3'
t4 = ifelseifelseif(T.eq(x1,x2), x1, T.eq(x1,5), x2, c2, t3, t3+0.5)
t4 = ifelseifelseif(T.eq(x1, x2), x1, T.eq(x1, 5), x2, c2, t3, t3 + 0.5)
t4.name = 't4'
f = function([c1,c2,x1,x2], t4, mode=Mode(linker='vm', optimizer='fast_run'))
print f(1, 0, numpy.array(10,dtype=x1.dtype),0)
assert f(1,0,numpy.array(10,dtype=x1.dtype),0) == 20.5
f = function([c1, c2, x1, x2], t4, mode=Mode(linker='vm',
optimizer='fast_run'))
if theano.config.vm.lazy is False:
try:
f(1, 0, numpy.array(10, dtype=x1.dtype), 0)
assert False
except NotImplementedOp.E:
pass
else:
print f(1, 0, numpy.array(10, dtype=x1.dtype), 0)
assert f(1, 0, numpy.array(10, dtype=x1.dtype), 0) == 20.5
print '... passed'
if __name__ == '__main__':
......
......@@ -36,7 +36,7 @@ def filter_vm_lazy(val):
return False
elif val == 'True' or val is True:
return True
elif val == 'None':
elif val == 'None' or val is None:
return None
else:
raise ValueError('Valid values for an vm.lazy parameter '
......
import copy
import gc
import numpy as np
......@@ -18,6 +19,11 @@ if theano.config.mode == 'FAST_COMPILE':
else:
mode_with_gpu = theano.compile.mode.get_default_mode().including('gpu')
# The GC need to be enabled for those tests to work correctly.
if not getattr(mode_with_gpu.linker, 'allow_gc', False):
mode_with_gpu.linker = copy.copy(mode_with_gpu.linker)
mode_with_gpu.linker.allow_gc = True
def freemem(extra_alloc=0):
"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论