提交 e3430b61 authored 作者: Ramana.S's avatar Ramana.S

Moved the mae_node method

上级 57a2f9c8
......@@ -34,26 +34,6 @@ class OpDecoratorTests(utt.InferShapeTester):
assert allclose(r, r0), (r, r0)
def test_make_node(self):
x = dmatrix('x')
x.tag.test_value = np.zeros((2, 2))
y = dvector('y')
y.tag.test_value = [0, 0]
with self.assertRaisesRegexp(NotImplementedError, "itypes not defined") :
@as_op(itypes=[], otypes=dvector)
def none_itypes(x,y):
return np.dot(x,y)
none_itypes(x,y)
with self.assertRaisesRegexp(NotImplementedError, "otypes not defined") :
@as_op(itypes=[dmatrix, dvector], otypes=[])
def none_otypes(x,y):
return np.dot(x,y)
none_otypes(x, y)
def test_2arg(self):
x = dmatrix('x')
x.tag.test_value = np.zeros((2, 2))
......
......@@ -765,6 +765,10 @@ class Op(utils.object2, PureOp, CLinkerOp):
Convenience class to bundle `PureOp` and `CLinkerOp`.
"""
def __new__(cls, *args, **kwargs):
# this function exists to silently and transparently ensure that all
# existing Ops get a _op_use_c_code attribute
......@@ -932,6 +936,23 @@ class Op(utils.object2, PureOp, CLinkerOp):
# condition: either there was no c_code, or it failed
return self.make_py_thunk(node, storage_map, compute_map, no_recycling)
def make_node(self, *inputs):
if not hasattr(self, 'itypes'):
raise NotImplementedError("itypes not defined")
if not hasattr(self, 'otypes') :
raise NotImplementedError("otypes not defined")
if len(inputs) != len(self.itypes):
raise ValueError("We expected %d inputs but got %d." %
(len(self.itypes), len(inputs)))
if not all(inp.type == it for inp, it in zip(inputs, self.itypes)):
raise TypeError(
"We expected inputs of types '%s' but got types '%s' " %
(str([inp.type for inp in inputs]), str(self.itypes)))
return theano.Apply(self, inputs, [o() for o in self.otypes])
def get_test_value(v):
......
......@@ -9,6 +9,7 @@ from six import string_types
from theano.gof.type import Type, Generic
from theano.gof.graph import Apply, Variable
import theano.tensor as T
from theano.compile import as_op
from theano import scalar
from theano import shared
......@@ -57,7 +58,8 @@ class MyType(Type):
class MyOp(Op):
__props__ = ()
'''
def make_node(self, *inputs):
inputs = list(map(as_variable, inputs))
for input in inputs:
......@@ -65,6 +67,7 @@ class MyOp(Op):
raise Exception("Error 1")
outputs = [MyType(sum([input.type.thingy for input in inputs]))()]
return Apply(self, inputs, outputs)
'''
MyOp = MyOp()
......@@ -381,5 +384,28 @@ def test_debug_error_message():
finally:
config.compute_test_value = prev_value
'''
def test_make_node():
x = T.dmatrix('x')
x.tag.test_value = numpy.zeros((2, 2))
y = T.dvector('y')
y.tag.test_value = [0, 0]
with unittest.assertRaisesRegexp(NotImplementedError, "itypes not defined") :
@as_op(itypes=[], otypes=T.dvector)
def none_itypes(x,y):
return numpy.dot(x,y)
none_itypes(x,y)
with assertRaisesRegexp(NotImplementedError, "otypes not defined") :
@as_op(itypes=[T.dmatrix, T.dvector], otypes=[])
def none_otypes(x,y):
return numpy.dot(x,y)
none_otypes(x, y)
'''
if __name__ == '__main__':
unittest.main()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论