提交 4ea384c3 authored 作者: John Salvatier's avatar John Salvatier

add infer_shape test to as_op decorator

上级 4fb3b885
...@@ -43,3 +43,24 @@ class OpDecoratorTests(unittest.TestCase): ...@@ -43,3 +43,24 @@ class OpDecoratorTests(unittest.TestCase):
assert allclose(r, r0), (r, r0) assert allclose(r, r0), (r, r0)
def test_infer_shape(self):
x = dmatrix('x')
x.tag.test_value=np.zeros((2,2))
y = dvector('y')
y.tag.test_value=[0,0]
def infer_shape(node, shapes):
x,y = shapes
return [y]
@as_op([dmatrix, dvector], dvector, infer_shape)
def diag_mult(x, y):
return np.diag(x) * y
fn = function([x, y], diag_mult(x, y).shape)
r = fn([[1.5, 5],[2, 2]], [1, 100])
r0 = (2,)
print r
assert allclose(r, r0), (r, r0)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论