提交 6a05466b authored 作者: Eric Larsen's avatar Eric Larsen 提交者: Frederic

testing infer_shape: Op PermuteRowElements

上级 b085ebbf
...@@ -5807,6 +5807,9 @@ class PermuteRowElements(Op): ...@@ -5807,6 +5807,9 @@ class PermuteRowElements(Op):
self._rec_perform(node, x, y, inverse, outs[0], curdim=0) self._rec_perform(node, x, y, inverse, outs[0], curdim=0)
def infer_shape(self, node, in_shapes):
return [in_shapes[0]]
def grad(self, inp, grads): def grad(self, inp, grads):
x, y, inverse = inp x, y, inverse = inp
gz, = grads gz, = grads
......
...@@ -34,7 +34,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as, ...@@ -34,7 +34,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as,
get_constant_value, ivector, reshape, scalar_from_tensor, scal, get_constant_value, ivector, reshape, scalar_from_tensor, scal,
iscalars, arange, dscalars, fvector, imatrix, numeric_grad, iscalars, arange, dscalars, fvector, imatrix, numeric_grad,
opt, ComplexError, TensorDot, lvector, true_div, max, min, Split, roll, opt, ComplexError, TensorDot, lvector, true_div, max, min, Split, roll,
tile, patternbroadcast, Eye, Shape, Default, Dot) tile, patternbroadcast, Eye, Shape, Default, Dot, PermuteRowElements)
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from theano.printing import debugprint from theano.printing import debugprint
...@@ -6120,6 +6120,26 @@ class TestInferShape(utt.InferShapeTester): ...@@ -6120,6 +6120,26 @@ class TestInferShape(utt.InferShapeTester):
[Join()(aiscal, admat, bdmat, cdmat)], [Join()(aiscal, admat, bdmat, cdmat)],
[aiscal_val, admat_val, bdmat_val, cdmat_val], Join) [aiscal_val, admat_val, bdmat_val, cdmat_val], Join)
# PermuteRowElements
abool = True
rng = numpy.random.RandomState(utt.fetch_seed())
advec_val = rand(5)
aivec_val = rng.permutation(5).astype('int32')
self._compile_and_check([advec, aivec],
[PermuteRowElements()(advec, aivec, abool)],
[advec_val, aivec_val], PermuteRowElements)
admat_val = rand(3, 5)
self._compile_and_check([admat, aivec],
[PermuteRowElements()(admat, aivec, abool)],
[admat_val, aivec_val], PermuteRowElements)
adtens_val = rand(3, 2, 5)
self._compile_and_check([adtens, aivec],
[PermuteRowElements()(adtens, aivec, abool)],
[adtens_val, aivec_val], PermuteRowElements)
if __name__ == '__main__': if __name__ == '__main__':
t = TestInferShape('setUp') t = TestInferShape('setUp')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论