提交 b1d5f152 authored 作者: Eric Larsen's avatar Eric Larsen 提交者: Frederic

updates to PermuterowElements

上级 0a32d9bc
...@@ -5814,7 +5814,17 @@ class PermuteRowElements(Op): ...@@ -5814,7 +5814,17 @@ class PermuteRowElements(Op):
self._rec_perform(node, x, y, inverse, outs[0], curdim=0) self._rec_perform(node, x, y, inverse, outs[0], curdim=0)
def infer_shape(self, node, in_shapes): def infer_shape(self, node, in_shapes):
return [in_shapes[0]] shp_x = in_shapes[0]
shp_y = in_shapes[1]
if len(shp_x) > len(shp_y):
out_shape = shp_x
elif len(shp_x) < len(shp_y):
out_shape = shp_y
else:
out_shape = []
for i in range(len(shp_x)):
out_shape.append(maximum(shp_x[i], shp_y[i]))
return [out_shape]
def grad(self, inp, grads): def grad(self, inp, grads):
x, y, inverse = inp x, y, inverse = inp
......
...@@ -36,7 +36,8 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as, ...@@ -36,7 +36,8 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as,
opt, ComplexError, TensorDot, lvector, true_div, max, min, Split, roll, opt, ComplexError, TensorDot, lvector, true_div, max, min, Split, roll,
tile, patternbroadcast, Eye, Shape, Default, Dot, PermuteRowElements, tile, patternbroadcast, Eye, Shape, Default, Dot, PermuteRowElements,
ScalarFromTensor, TensorFromScalar, dtensor4, Rebroadcast, Alloc, ScalarFromTensor, TensorFromScalar, dtensor4, Rebroadcast, Alloc,
dtensor3, SpecifyShape, Mean, IncSubtensor, AdvancedIncSubtensor1) dtensor3, SpecifyShape, Mean, IncSubtensor, AdvancedIncSubtensor1,
itensor3)
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from theano.printing import debugprint from theano.printing import debugprint
...@@ -6136,10 +6137,33 @@ class TestInferShape(utt.InferShapeTester): ...@@ -6136,10 +6137,33 @@ class TestInferShape(utt.InferShapeTester):
[PermuteRowElements()(admat, aivec, abool)], [PermuteRowElements()(admat, aivec, abool)],
[admat_val, aivec_val], PermuteRowElements) [admat_val, aivec_val], PermuteRowElements)
adtens_val = rand(3, 2, 5) adtens3 = dtensor3()
self._compile_and_check([adtens, aivec], adtens3_val = rand(3, 2, 5)
[PermuteRowElements()(adtens, aivec, abool)], self._compile_and_check([adtens3, aivec],
[adtens_val, aivec_val], PermuteRowElements) [PermuteRowElements()(adtens3, aivec, abool)],
[adtens3_val, aivec_val], PermuteRowElements)
aimat = imatrix()
perma = rng.permutation(5).astype('int32')
permb = rng.permutation(5).astype('int32')
permc = rng.permutation(5).astype('int32')
aimat_val = numpy.vstack((perma, permb, permc))
admat_val = rand(3, 5)
self._compile_and_check([admat, aimat],
[PermuteRowElements()(admat, aimat, abool)],
[admat_val, aimat_val], PermuteRowElements)
aitens3 = itensor3()
perma = rng.permutation(5).astype('int32')
permb = rng.permutation(5).astype('int32')
permc = rng.permutation(5).astype('int32')
bimat_val = numpy.vstack((perma, permb, permc))
aitens3_val = numpy.empty((2, 3, 5), 'int32')
aitens3_val[0, ::, ::] = aimat_val
aitens3_val[1, ::, ::] = bimat_val
self._compile_and_check([admat, aitens3],
[PermuteRowElements()(admat, aitens3, abool)],
[admat_val, aitens3_val], PermuteRowElements)
# ScalarFromTensor # ScalarFromTensor
aiscal = iscalar() aiscal = iscalar()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论