提交 565202f8 authored 作者: Eric Larsen's avatar Eric Larsen 提交者: Frederic

testing infer_shape: Op Eye

上级 f84b04d4
...@@ -2879,6 +2879,10 @@ class Eye(gof.Op): ...@@ -2879,6 +2879,10 @@ class Eye(gof.Op):
out, = out_ out, = out_
out[0] = numpy.eye(n, m, k, dtype=self.dtype) out[0] = numpy.eye(n, m, k, dtype=self.dtype)
def infer_shape(self, node, in_shapes):
out_shape = [node.inputs[0], node.inputs[1]]
return [out_shape]
def grad(self, inp, grads): def grad(self, inp, grads):
return [None, None, None] return [None, None, None]
......
...@@ -34,7 +34,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as, ...@@ -34,7 +34,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as,
get_constant_value, ivector, reshape, scalar_from_tensor, scal, get_constant_value, ivector, reshape, scalar_from_tensor, scal,
iscalars, arange, dscalars, fvector, imatrix, numeric_grad, iscalars, arange, dscalars, fvector, imatrix, numeric_grad,
opt, ComplexError, TensorDot, lvector, true_div, max, min, Split, roll, opt, ComplexError, TensorDot, lvector, true_div, max, min, Split, roll,
tile, patternbroadcast) tile, patternbroadcast, Eye)
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from theano.printing import debugprint from theano.printing import debugprint
...@@ -6060,6 +6060,21 @@ class TestInferShape(utt.InferShapeTester): ...@@ -6060,6 +6060,21 @@ class TestInferShape(utt.InferShapeTester):
[Flatten(outdim)(adtens)], [Flatten(outdim)(adtens)],
[adtens_val], Flatten) [adtens_val], Flatten)
# Eye
aiscal = iscalar()
biscal = iscalar()
ciscal = iscalar()
self._compile_and_check([aiscal, biscal, ciscal],
[Eye()(aiscal, biscal, ciscal)],
[4, 4, 0], Eye)
self._compile_and_check([aiscal, biscal, ciscal],
[Eye()(aiscal, biscal, ciscal)],
[4, 5, 0], Eye)
self._compile_and_check([aiscal, biscal, ciscal],
[Eye()(aiscal, biscal, ciscal)],
[3, 5, 0], Eye)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论