提交 515249bd authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #6289 from nouiz/fix-6288

[BUG] Fix gh-6287 and a regression introduced in gh-6218
......@@ -984,8 +984,8 @@ class GpuAllocEmpty(HideC, AllocEmpty):
output.tag.nan_guard_mode_check = False
return Apply(self, sh, [output])
def debug_perform(self, node, inputs, out_, ctx):
self.perform(node, inputs, out_, ctx)
def debug_perform(self, node, inputs, out_, params):
self.perform(node, inputs, out_, params)
out_[0][0][:] = -123456789
def perform(self, node, inputs, out_, params):
......
......@@ -6663,8 +6663,8 @@ class AllocEmpty(gof.Op):
output.tag.nan_guard_mode_check = False
return Apply(self, shape, [output])
def debug_perform(self, node, inputs, out_):
self.perform(node, inputs, out_)
def debug_perform(self, node, inputs, out_, params):
self.perform(node, inputs, out_, params)
out_[0][0].fill(-123456789)
def perform(self, node, inputs, out_, params):
......
......@@ -182,6 +182,7 @@ def local_dimshuffle_subtensor(node):
return False
new_order = node.op.new_order
# new order could be empty
# Verif that we don't change dimensions order.
if len(new_order) > 1:
past_dim = new_order[0]
for dim in new_order[1:]:
......@@ -216,7 +217,7 @@ def local_dimshuffle_subtensor(node):
j = 0
slice_i = -1
subtensor_removed_dims = 0
for idx in input_.owner.op.idx_list:
for i, idx in enumerate(input_.owner.op.idx_list):
if isinstance(idx, slice):
past_j = j
slice_i += 1
......@@ -228,7 +229,7 @@ def local_dimshuffle_subtensor(node):
# that's where we want to index with 0 if it is also at
# the same spot of a missing dim
if past_j == j and slice_i in missing_dims:
new_idx_list[j] = zero
new_idx_list[i] = zero
new_inputs += [zero]
else:
new_inputs += [input_.owner.inputs[1 + j]]
......
......@@ -207,3 +207,9 @@ def test_local_dimshuffle_subtensor():
topo = f.maker.fgraph.toposort()
assert any([not isinstance(x, DimShuffle) for x in topo])
assert f(np.random.rand(5, 1, 4, 1), 2).shape == (4,)
# Test a corner case that had Theano return a bug.
x = tensor.dtensor4('x')
x = tensor.patternbroadcast(x, (False, True, False, False))
assert x[:,:, 0:3, ::-1].dimshuffle(0,2,3).eval({x: np.ones((5, 1, 6, 7))}).shape == (5, 3, 7)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论