提交 27f80b39 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Refactor tests.gof.test_lazy and enable an incorrectly named test

上级 50ab0430
from copy import deepcopy
import pytest
import numpy as np
import theano
from theano.gof.op import PureOp
from theano.gof import Apply, generic
import theano.tensor as tt
from copy import deepcopy
from theano import function, Mode
from theano.gof import Apply, generic
from theano.gof.op import PureOp
from theano.ifelse import ifelse
import theano.tensor as tt
class IfElseIfElseIf(PureOp):
......@@ -88,16 +91,17 @@ class IfElseIfElseIf(PureOp):
return thunk
class NotImplementedOp(PureOp):
class E(Exception):
pass
class NotImplementedOpException(Exception):
pass
class NotImplementedOp(PureOp):
def make_node(self, x):
return Apply(self, [x], [x.type()])
def make_thunk(self, node, storage_map, compute_map, no_recycling, impl):
def thunk():
raise self.E()
raise NotImplementedOpException()
thunk.lazy = False
return thunk
......@@ -109,13 +113,17 @@ def test_ifelse():
c = generic()
notimpl = NotImplementedOp()
lazys = [True]
# We need lazy to end up being True for this test.
if theano.config.vm.lazy in [True, None]:
lazys = [True, None]
cloops = [True, False]
if theano.config.cxx == "":
cloops = [False]
for cloop in cloops:
for lazy in lazys:
linker = theano.gof.vm.VM_Linker(use_cloop=cloop, lazy=lazy)
......@@ -125,21 +133,13 @@ def test_ifelse():
mode=Mode(linker=linker, optimizer="fast_run"),
)
try:
# print "case 1"
with pytest.raises(NotImplementedOpException):
f(1, "a", "b")
assert False
except NotImplementedOp.E:
pass
# print "... passed"
# print "case 2"
# print f(0, 'a', 'b')
assert f(0, "a", "b") == "b"
# print "... passed"
def more_complex_test():
def test_nested():
notimpl = NotImplementedOp()
ifelseifelseif = IfElseIfElseIf()
......@@ -156,18 +156,11 @@ def more_complex_test():
t4 = ifelseifelseif(tt.eq(x1, x2), x1, tt.eq(x1, 5), x2, c2, t3, t3 + 0.5)
t4.name = "t4"
f = function([c1, c2, x1, x2], t4, mode=Mode(linker="vm", optimizer="fast_run"))
if theano.config.vm.lazy is False:
try:
f(1, 0, np.array(10, dtype=x1.dtype), 0)
assert False
except NotImplementedOp.E:
pass
else:
print(f(1, 0, np.array(10, dtype=x1.dtype), 0))
assert f(1, 0, np.array(10, dtype=x1.dtype), 0) == 20.5
print("... passed")
if __name__ == "__main__":
more_complex_test()
linker = theano.gof.vm.VM_Linker(lazy=False)
f = function([c1, c2, x1, x2], t4, mode=Mode(linker=linker, optimizer="fast_run"))
with pytest.raises(NotImplementedOpException):
f(1, 0, np.array(10, dtype=x1.dtype), 0)
linker = theano.gof.vm.VM_Linker(lazy=True)
f = function([c1, c2, x1, x2], t4, mode=Mode(linker=linker, optimizer="fast_run"))
assert f(1, 0, np.array(10, dtype=x1.dtype), 0) == 20.5
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论