提交 f8f894ad authored 作者: Eric Larsen's avatar Eric Larsen 提交者: Frederic

testing infer_shape: Op MaxAndArgmax

上级 67ab608a
...@@ -35,7 +35,8 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as, ...@@ -35,7 +35,8 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as,
iscalars, arange, dscalars, fvector, imatrix, numeric_grad, iscalars, arange, dscalars, fvector, imatrix, numeric_grad,
opt, ComplexError, TensorDot, lvector, true_div, max, min, Split, roll, opt, ComplexError, TensorDot, lvector, true_div, max, min, Split, roll,
tile, patternbroadcast, Eye, Shape, Default, Dot, PermuteRowElements, tile, patternbroadcast, Eye, Shape, Default, Dot, PermuteRowElements,
ScalarFromTensor, TensorFromScalar, dtensor4, Rebroadcast, Alloc) ScalarFromTensor, TensorFromScalar, dtensor4, Rebroadcast, Alloc,
dtensor3)
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from theano.printing import debugprint from theano.printing import debugprint
...@@ -6180,6 +6181,35 @@ class TestInferShape(utt.InferShapeTester): ...@@ -6180,6 +6181,35 @@ class TestInferShape(utt.InferShapeTester):
[adscal_val, aiscal_val, biscal_val, [adscal_val, aiscal_val, biscal_val,
ciscal_val, discal_val], Alloc) ciscal_val, discal_val], Alloc)
# MaxAndArgmax,
# Note: axis as a tensor.iscalar or constant conflicts with
# make_node in basic
adtens3 = dtensor3()
aiscal = iscalar()
aconst = 1
aiscal_val = randint(0, 2, size=())
adtens3_val = rand(4, 5, 3)
self._compile_and_check([adtens3],
MaxAndArgmax()(adtens3, None),
[adtens3_val], MaxAndArgmax)
self._compile_and_check([adtens3],
MaxAndArgmax()(adtens3, 0),
[adtens3_val], MaxAndArgmax)
self._compile_and_check([adtens3],
MaxAndArgmax()(adtens3, 1),
[adtens3_val], MaxAndArgmax)
self._compile_and_check([adtens3],
MaxAndArgmax()(adtens3, 2),
[adtens3_val], MaxAndArgmax)
self._compile_and_check([adtens3],
MaxAndArgmax()(adtens3, [0, 1, 2]),
[adtens3_val], MaxAndArgmax)
if __name__ == '__main__': if __name__ == '__main__':
t = TestInferShape('setUp') t = TestInferShape('setUp')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论