提交 68ab8d91 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

New algortihm for the perform that avoids one of the reshapes.

上级 0f1c32b5
......@@ -143,10 +143,6 @@ class DimShuffle(Op):
# list of dimensions of the input to drop
self.drop = []
# this maps i before dropping dimensions to j after dropping dimensions
# so self.shuffle can be set properly later on
i2j = {}
j = 0
for i, b in enumerate(input_broadcastable):
if i not in new_order:
# we want to drop this dimension because it's not a value in
......@@ -158,14 +154,9 @@ class DimShuffle(Op):
raise ValueError(
"You cannot drop a non-broadcastable dimension.",
(input_broadcastable, new_order))
else:
i2j[i] = j
j += 1
# transposition of non-broadcastable dimensions
# This is how the dimensions will be permuted, without accounting for
# the extra 'x' broadcastable dimensions to insert.
self.shuffle = [i2j[x] for x in new_order if x != 'x']
# this is the list of the original dimensions that we keep
self.shuffle = [x for x in new_order if x != 'x']
# list of dimensions of the output that are broadcastable and were not
# in the original input
......@@ -237,16 +228,12 @@ class DimShuffle(Op):
res = input
if type(res) != numpy.ndarray and type(res) != numpy.memmap:
raise TypeError(res)
shape = list(res.shape)
for drop in reversed(self.drop):
shape.pop(drop)
res = res.reshape(shape)
# transpose
res = res.transpose(self.shuffle)
res = res.transpose(self.shuffle+self.drop)
# augment
shape = list(res.shape)
shape = list(res.shape[:len(self.shuffle)])
for augm in self.augment:
shape.insert(augm, 1)
res = res.reshape(shape)
......@@ -259,9 +246,6 @@ class DimShuffle(Op):
def infer_shape(self, node, shapes):
ishp, = shapes
ishp = list(ishp)
for drop in reversed(self.drop):
del ishp[drop]
# transpose
rval = [ishp[i] for i in self.shuffle]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论