提交 99536643 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Added a test case for the situation of an Op has no input

上级 af520b15
from copy import copy from copy import copy
import theano
from theano.gof.op import * from theano.gof.op import *
from theano.gof.type import Type, Generic from theano.gof.type import Type, Generic
from theano.gof.graph import Apply, Variable from theano.gof.graph import Apply, Variable
...@@ -23,6 +25,11 @@ class MyType(Type): ...@@ -23,6 +25,11 @@ class MyType(Type):
def __repr__(self): def __repr__(self):
return str(self.thingy) return str(self.thingy)
def filter(self, x, strict=True, allow_downcast=None):
# Dummy filter: we want this type to represent strings that
# start with `self.thingy`.
assert isinstance(x, str) and x.startswith(self.thingy)
return x
class MyOp(Op): class MyOp(Op):
...@@ -37,6 +44,23 @@ class MyOp(Op): ...@@ -37,6 +44,23 @@ class MyOp(Op):
MyOp = MyOp() MyOp = MyOp()
class NoInputOp(Op):
"""An Op to test the corner-case of an Op with no input."""
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
def make_node(self):
return Apply(self, [], [MyType('test')()])
def perform(self, node, inputs, output_storage):
output_storage[0][0] = 'test Op no input'
class TestOp: class TestOp:
# Sanity tests # Sanity tests
...@@ -56,6 +80,11 @@ class TestOp: ...@@ -56,6 +80,11 @@ class TestOp:
if str(e) != "Error 1": if str(e) != "Error 1":
raise raise
def test_op_no_input(self):
x = NoInputOp()()
f = theano.function([], x)
rval = f()
assert rval == 'test Op no input'
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论