提交 a50646d5 authored 作者: Eric Larsen's avatar Eric Larsen 提交者: Frederic

testing infer_shape: Op Shape

上级 565202f8
...@@ -2064,6 +2064,9 @@ class Shape(Op): ...@@ -2064,6 +2064,9 @@ class Shape(Op):
out, = out_ out, = out_
out[0] = theano._asarray(x.shape, dtype='int64') out[0] = theano._asarray(x.shape, dtype='int64')
def infer_shape(self, node, in_shapes):
return [[len(in_shapes[0])]]
def grad(self, inp, grads): def grad(self, inp, grads):
return [None] return [None]
......
...@@ -34,7 +34,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as, ...@@ -34,7 +34,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as,
get_constant_value, ivector, reshape, scalar_from_tensor, scal, get_constant_value, ivector, reshape, scalar_from_tensor, scal,
iscalars, arange, dscalars, fvector, imatrix, numeric_grad, iscalars, arange, dscalars, fvector, imatrix, numeric_grad,
opt, ComplexError, TensorDot, lvector, true_div, max, min, Split, roll, opt, ComplexError, TensorDot, lvector, true_div, max, min, Split, roll,
tile, patternbroadcast, Eye) tile, patternbroadcast, Eye, Shape)
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from theano.printing import debugprint from theano.printing import debugprint
...@@ -6076,6 +6076,13 @@ class TestInferShape(utt.InferShapeTester): ...@@ -6076,6 +6076,13 @@ class TestInferShape(utt.InferShapeTester):
[Eye()(aiscal, biscal, ciscal)], [Eye()(aiscal, biscal, ciscal)],
[3, 5, 0], Eye) [3, 5, 0], Eye)
# Shape
# 'opt.Makevector' precludes optimizer from disentangling
# elements of shape
self._compile_and_check([adtens],
[Shape()(adtens)],
[adtens_val], (opt.MakeVector, Shape))
if __name__ == '__main__': if __name__ == '__main__':
t = TestInferShape('setUp') t = TestInferShape('setUp')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论