提交 6812b4bc authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Declare classes outside of test class so the Op can be pickled

上级 b0de5da7
...@@ -14,6 +14,30 @@ from theano.scan_module import scan ...@@ -14,6 +14,30 @@ from theano.scan_module import scan
from theano.tensor.basic import _allclose from theano.tensor.basic import _allclose
# Used in TestComputeTestValue.test_no_perform
class IncOneC(Op):
"""An Op with only a C (c_code) implementation"""
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
def make_node(self, input):
input = scalar.as_scalar(input)
output = input.type()
return Apply(self, [input], [output])
def c_code_cache_version(self):
return (1,)
def c_code(self, node, name, inputs, outputs, sub):
x, = inputs
z, = outputs
return "%(z)s = %(x)s + 1;" % locals()
class TestComputeTestValue(unittest.TestCase): class TestComputeTestValue(unittest.TestCase):
def test_variable_only(self): def test_variable_only(self):
...@@ -338,28 +362,6 @@ class TestComputeTestValue(unittest.TestCase): ...@@ -338,28 +362,6 @@ class TestComputeTestValue(unittest.TestCase):
def test_no_perform(self): def test_no_perform(self):
if not theano.config.cxx: if not theano.config.cxx:
raise SkipTest("G++ not available, so we need to skip this test.") raise SkipTest("G++ not available, so we need to skip this test.")
class IncOneC(Op):
"""An Op with only a C (c_code) implementation"""
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
def make_node(self, input):
input = scalar.as_scalar(input)
output = input.type()
return Apply(self, [input], [output])
def c_code_cache_version(self):
return (1,)
def c_code(self, node, name, inputs, outputs, sub):
x, = inputs
z, = outputs
return "%(z)s = %(x)s + 1;" % locals()
orig_compute_test_value = theano.config.compute_test_value orig_compute_test_value = theano.config.compute_test_value
try: try:
...@@ -368,6 +370,8 @@ class TestComputeTestValue(unittest.TestCase): ...@@ -368,6 +370,8 @@ class TestComputeTestValue(unittest.TestCase):
i = scalar.int32('i') i = scalar.int32('i')
i.tag.test_value = 3 i.tag.test_value = 3
# Class IncOneC is defined outside of the TestComputeTestValue
# so it can be pickled and unpickled
o = IncOneC()(i) o = IncOneC()(i)
# Check that the perform function is not implemented # Check that the perform function is not implemented
......
...@@ -879,15 +879,8 @@ class T_using_gpu(unittest.TestCase): ...@@ -879,15 +879,8 @@ class T_using_gpu(unittest.TestCase):
for x in f.maker.fgraph.toposort()]) for x in f.maker.fgraph.toposort()])
class T_fibby(unittest.TestCase): # Used in T_fibby
## All tests here belong to class Fibby(theano.Op):
## http://deeplearning.net/software/theano/extending/fibby.html
## Theano/doc/extending/fibby.txt
## Any change you do here also add it to the tutorial !
def test_fibby_1(self):
class Fibby(theano.Op):
""" """
An arbitrarily generalized Fibbonacci sequence An arbitrarily generalized Fibbonacci sequence
...@@ -935,6 +928,17 @@ class T_fibby(unittest.TestCase): ...@@ -935,6 +928,17 @@ class T_fibby(unittest.TestCase):
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return (1,)
class T_fibby(unittest.TestCase):
## All tests here belong to
## http://deeplearning.net/software/theano/extending/fibby.html
## Theano/doc/extending/fibby.txt
## Any change you do here also add it to the tutorial !
def test_fibby_1(self):
# The definition of class Fibby is done outside of the test,
# so the object can be pickled.
fibby = Fibby() fibby = Fibby()
from theano.tensor.opt import (get_scalar_constant_value, from theano.tensor.opt import (get_scalar_constant_value,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论