提交 4185f225 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Add test for make_thunk

上级 dbf8b887
from copy import copy from copy import copy
import unittest
import numpy
import theano import theano
...@@ -6,6 +9,8 @@ from theano.gof.op import * ...@@ -6,6 +9,8 @@ from theano.gof.op import *
from theano.gof.type import Type, Generic from theano.gof.type import Type, Generic
from theano.gof.graph import Apply, Variable from theano.gof.graph import Apply, Variable
from theano import scalar
def as_variable(x): def as_variable(x):
assert isinstance(x, Variable) assert isinstance(x, Variable)
return x return x
...@@ -86,6 +91,96 @@ class TestOp: ...@@ -86,6 +91,96 @@ class TestOp:
rval = f() rval = f()
assert rval == 'test Op no input' assert rval == 'test Op no input'
class TestMakeThunk(unittest.TestCase):
def test_no_c_code(self):
class IncOnePython(Op):
"""An Op with only a Python (perform) implementation"""
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
def make_node(self, input):
input = scalar.as_scalar(input)
output = input.type()
return Apply(self, [input], [output])
def perform(self, node, inputs, outputs):
input, = inputs
output, = outputs
output[0] = input + 1
i = scalar.int32('i')
o = IncOnePython()(i)
# Check that the c_code function is not implemented
self.assertRaises((NotImplementedError, utils.MethodNotDefined),
o.owner.op.c_code,
o.owner, 'o', ['x'], 'z', {'fail': ''})
storage_map = {
i: [numpy.int32(3)],
o: [None]}
compute_map = {
i: [True],
o: [False]}
thunk = o.owner.op.make_thunk(o.owner, storage_map, compute_map,
no_recycling=[])
required = thunk()
# Check everything went OK
assert not required # We provided all inputs
assert compute_map[o][0]
assert storage_map[o][0] == 4
def test_no_perform(self):
class IncOneC(Op):
"""An Op with only a C (c_code) implementation"""
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
def make_node(self, input):
input = scalar.as_scalar(input)
output = input.type()
return Apply(self, [input], [output])
def c_code(self, node, name, inputs, outputs, sub):
x, = inputs
z, = outputs
return "%(z)s = %(x)s + 1;" % locals()
i = scalar.int32('i')
o = IncOneC()(i)
# Check that the perform function is not implemented
self.assertRaises((NotImplementedError, utils.MethodNotDefined),
o.owner.op.perform,
o.owner, 0, [None])
storage_map = {
i: [numpy.int32(3)],
o: [None]}
compute_map = {
i: [True],
o: [False]}
thunk = o.owner.op.make_thunk(o.owner, storage_map, compute_map,
no_recycling=[])
required = thunk()
# Check everything went OK
assert not required # We provided all inputs
assert compute_map[o][0]
assert storage_map[o][0] == 4
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论