提交 f585ebd2 authored 作者: Eric Larsen's avatar Eric Larsen 提交者: Frederic

testing infer_shape: Op Rebroadcast

上级 a2827cec
...@@ -35,7 +35,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as, ...@@ -35,7 +35,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as,
iscalars, arange, dscalars, fvector, imatrix, numeric_grad, iscalars, arange, dscalars, fvector, imatrix, numeric_grad,
opt, ComplexError, TensorDot, lvector, true_div, max, min, Split, roll, opt, ComplexError, TensorDot, lvector, true_div, max, min, Split, roll,
tile, patternbroadcast, Eye, Shape, Default, Dot, PermuteRowElements, tile, patternbroadcast, Eye, Shape, Default, Dot, PermuteRowElements,
ScalarFromTensor, TensorFromScalar) ScalarFromTensor, TensorFromScalar, dtensor4, Rebroadcast)
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from theano.printing import debugprint from theano.printing import debugprint
...@@ -6155,6 +6155,14 @@ class TestInferShape(utt.InferShapeTester): ...@@ -6155,6 +6155,14 @@ class TestInferShape(utt.InferShapeTester):
[TensorFromScalar()(aiscal)], [TensorFromScalar()(aiscal)],
[4.], TensorFromScalar) [4.], TensorFromScalar)
# Rebroadcast:
adtens4 = dtensor4()
adict = [(0, False), (1, True), (2, False), (3, True)]
adtens4_val = rand(2, 1, 3, 1)
self._compile_and_check([adtens4],
[Rebroadcast(*adict)(adtens4)],
[adtens4_val], Rebroadcast)
if __name__ == '__main__': if __name__ == '__main__':
t = TestInferShape('setUp') t = TestInferShape('setUp')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论