提交 d498f7e0 authored 作者: ChienliMa's avatar ChienliMa

still have bug

上级 eb8cadb2
...@@ -764,6 +764,7 @@ class FillDiagonalOffset(gof.Op): ...@@ -764,6 +764,7 @@ class FillDiagonalOffset(gof.Op):
a = inputs[0].copy() a = inputs[0].copy()
val = inputs[1] val = inputs[1]
offset = inputs[2] offset = inputs[2]
height, width = a.shape
# offset should be an integer # offset should be an integer
if offset % 1 != 0: if offset % 1 != 0:
...@@ -775,10 +776,12 @@ class FillDiagonalOffset(gof.Op): ...@@ -775,10 +776,12 @@ class FillDiagonalOffset(gof.Op):
# the offset function is only implemented for matrices # the offset function is only implemented for matrices
if offset >= 0: if offset >= 0:
start = offset start = offset
num_of_step = min( min(width,height), width - offset)
else: else:
start = - offset * a.shape[0] start = - offset * a.shape[0]
num_of_step = min( min(width,height), height + offset)
step = a.shape[1] + 1 step = a.shape[1] + 1
end = a.shape[1] * a.shape[1] end = start + step * num_of_step
# Write the value out into the diagonal. # Write the value out into the diagonal.
a.flat[start:end:step] = val a.flat[start:end:step] = val
......
...@@ -42,7 +42,7 @@ class TestFillDiagonalOffset(utt.InferShapeTester): ...@@ -42,7 +42,7 @@ class TestFillDiagonalOffset(utt.InferShapeTester):
out = f(a, val, test_offset) out = f(a, val, test_offset)
# We can't use numpy.fill_diagonal as it is bugged. # We can't use numpy.fill_diagonal as it is bugged.
assert numpy.allclose(numpy.diag(out, test_offset), val) assert numpy.allclose(numpy.diag(out, test_offset), val)
pdb.set_trace() #pdb.set_trace()
assert (out == val).sum() == min(a.shape) assert (out == val).sum() == min(a.shape)
def test_gradient(self): def test_gradient(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论