提交 83794e06 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Finish tests for compute_test_value when there is only C or only Py code.

上级 4e74b47b
......@@ -5,9 +5,12 @@ import unittest
import theano
from theano import config
from theano import scalar
from theano import tensor as T
from theano.tensor.basic import _allclose
from theano.gof import Apply, Op
from theano.gof import utils
from theano.scan_module import scan
from theano.tensor.basic import _allclose
class TestComputeTestValue(unittest.TestCase):
......@@ -277,28 +280,85 @@ class TestComputeTestValue(unittest.TestCase):
finally:
theano.config.compute_test_value = orig_compute_test_value
def test_no_c_code(self):
class IncOnePython(Op):
"""An Op with only a Python (perform) implementation"""
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
def make_node(self, input):
input = scalar.as_scalar(input)
output = input.type()
return Apply(self, [input], [output])
def perform(self, node, inputs, outputs):
input, = inputs
output, = outputs
output[0] = input + 1
orig_compute_test_value = theano.config.compute_test_value
try:
theano.config.compute_test_value = 'raise'
i = scalar.int32('i')
i.tag.test_value = 3
o = IncOnePython()(i)
# Check that the c_code function is not implemented
self.assertRaises((NotImplementedError, utils.MethodNotDefined),
o.owner.op.c_code,
o.owner, 'o', ['x'], 'z', {'fail': ''})
assert hasattr(o.tag, 'test_value')
assert o.tag.test_value == 4
finally:
theano.config.compute_test_value = orig_compute_test_value
def test_no_perform(self):
pass
class IncOneC(Op):
"""An Op with only a C (c_code) implementation"""
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
def make_node(self, input):
input = scalar.as_scalar(input)
output = input.type()
return Apply(self, [input], [output])
def c_code(self, node, name, inputs, outputs, sub):
x, = inputs
z, = outputs
return "%(z)s = %(x)s + 1;" % locals()
def test_no_c_code(self):
orig_compute_test_value = theano.config.compute_test_value
try:
theano.config.compute_test_value = 'raise'
# int_div has no C code for the moment
# If that changes, use another op with no C code here
i = T.iscalar('i')
j = T.iscalar('j')
i = scalar.int32('i')
i.tag.test_value = 3
j.tag.test_value = 2
o = i // j
o = IncOneC()(i)
# Check that the c_code function is not implemented
self.assertRaises(NotImplementedError, o.owner.op.c_code,
o.owner, 'o', ['x', 'y'], 'z', {'fail': ''})
self.assertRaises((NotImplementedError, utils.MethodNotDefined),
o.owner.op.perform,
o.owner, 0, [None])
assert hasattr(o.tag, 'test_value')
assert o.tag.test_value == 1
assert o.tag.test_value == 4
finally:
theano.config.compute_test_value = orig_compute_test_value
pass
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论