提交 e53f1103 authored 作者: Eric Larsen's avatar Eric Larsen 提交者: Frederic

testing infer_shape: Op SpecifyShape

上级 1c3a65e0
...@@ -36,7 +36,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as, ...@@ -36,7 +36,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as,
opt, ComplexError, TensorDot, lvector, true_div, max, min, Split, roll, opt, ComplexError, TensorDot, lvector, true_div, max, min, Split, roll,
tile, patternbroadcast, Eye, Shape, Default, Dot, PermuteRowElements, tile, patternbroadcast, Eye, Shape, Default, Dot, PermuteRowElements,
ScalarFromTensor, TensorFromScalar, dtensor4, Rebroadcast, Alloc, ScalarFromTensor, TensorFromScalar, dtensor4, Rebroadcast, Alloc,
dtensor3) dtensor3, SpecifyShape)
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from theano.printing import debugprint from theano.printing import debugprint
...@@ -6229,6 +6229,13 @@ class TestInferShape(utt.InferShapeTester): ...@@ -6229,6 +6229,13 @@ class TestInferShape(utt.InferShapeTester):
[ARange('int64')(aiscal, biscal, ciscal)], [ARange('int64')(aiscal, biscal, ciscal)],
[0, 0, 1], ARange) [0, 0, 1], ARange)
# SpecifyShape
aivec_val = [3, 4, 2, 5]
adtens4_val = rand(*aivec_val)
self._compile_and_check([adtens4, aivec],
[SpecifyShape()(adtens4, aivec)],
[adtens4_val, aivec_val], SpecifyShape)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论