提交 179d464b authored 作者: Frederic's avatar Frederic

fix gh-1461

上级 39db1f8e
...@@ -7918,6 +7918,15 @@ class Dot(Op): ...@@ -7918,6 +7918,15 @@ class Dot(Op):
xgrad = dot(gz, y.T) xgrad = dot(gz, y.T)
ygrad = dot(x.T, gz) ygrad = dot(x.T, gz)
# If x or y contains broadcastable dimensions but only one of
# them knows that a matching dimension is broadcastable, the
# above code doesn't always return the right broadcast pattern.
# This causes problems down the road. See gh-1461.
if xgrad.broadcastable != x.broadcastable:
xgrad = patternbroadcast(xgrad, x.broadcastable)
if ygrad.broadcastable != y.broadcastable:
ygrad = patternbroadcast(ygrad, y.broadcastable)
rval = xgrad, ygrad rval = xgrad, ygrad
for elem in rval: for elem in rval:
......
...@@ -5007,6 +5007,12 @@ class t_dot(unittest.TestCase): ...@@ -5007,6 +5007,12 @@ class t_dot(unittest.TestCase):
tval = val_for(t) tval = val_for(t)
f(xval, yval, tval) # debugmode checks result f(xval, yval, tval) # debugmode checks result
if (dtype0.startswith('float') and
dtype1.startswith('float')):
g = grad(z.sum(), x)
assert g.broadcastable == x.broadcastable
g = grad(z.sum(), y)
assert g.broadcastable == y.broadcastable
class T_tensorfromscalar(unittest.TestCase): class T_tensorfromscalar(unittest.TestCase):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论