提交 c4fb7e56 authored 作者: Ramana.S's avatar Ramana.S

itypes and otypes checked and added test

上级 91f08497
...@@ -519,6 +519,13 @@ class FromFunctionOp(gof.Op): ...@@ -519,6 +519,13 @@ class FromFunctionOp(gof.Op):
return 'FromFunctionOp{%s}' % self.__fn.__name__ return 'FromFunctionOp{%s}' % self.__fn.__name__
def make_node(self, *inputs): def make_node(self, *inputs):
if not self.itypes:
raise NotImplementedError("itypes not defined")
if not self.otypes :
raise NotImplementedError("otypes not defined")
if len(inputs) != len(self.itypes): if len(inputs) != len(self.itypes):
raise ValueError("We expected %d inputs but got %d." % raise ValueError("We expected %d inputs but got %d." %
(len(self.itypes), len(inputs))) (len(self.itypes), len(inputs)))
......
...@@ -34,6 +34,25 @@ class OpDecoratorTests(utt.InferShapeTester): ...@@ -34,6 +34,25 @@ class OpDecoratorTests(utt.InferShapeTester):
assert allclose(r, r0), (r, r0) assert allclose(r, r0), (r, r0)
def test_make_node(self):
x = dmatrix('x')
x.tag.test_value = np.zeros((2, 2))
y = dvector('y')
y.tag.test_value = [0, 0]
with self.assertRaises(NotImplementedError):
@as_op(itypes=[dmatrix, dvector], otypes=[])
def none_otypes(x,y):
return np.dot(x,y)
@as_op(itypes=[], otypes=dvector)
def none_itypes(x,y):
return np.dot(x,y)
none_itypes(x,y)
none_otypes(x, y)
def test_2arg(self): def test_2arg(self):
x = dmatrix('x') x = dmatrix('x')
x.tag.test_value = np.zeros((2, 2)) x.tag.test_value = np.zeros((2, 2))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论