Commit 6e57a08d authored by ricardoV94, committed by Ricardo Vieira

Fix too strict type check in `_sum_grad_over_bcasted_dims`

Parent 4c7b4940
...@@ -2027,7 +2027,6 @@ def _sum_grad_over_bcasted_dims(x, gx): ...@@ -2027,7 +2027,6 @@ def _sum_grad_over_bcasted_dims(x, gx):
if gx.broadcastable != x.broadcastable: if gx.broadcastable != x.broadcastable:
x_dim_added = gx.ndim - x.ndim x_dim_added = gx.ndim - x.ndim
x_broad = (True,) * x_dim_added + x.broadcastable x_broad = (True,) * x_dim_added + x.broadcastable
assert sum(gx.broadcastable) <= sum(x_broad)
axis_to_sum = [] axis_to_sum = []
for i in range(gx.ndim): for i in range(gx.ndim):
if gx.broadcastable[i] is False and x_broad[i] is True: if gx.broadcastable[i] is False and x_broad[i] is True:
...@@ -2045,7 +2044,14 @@ def _sum_grad_over_bcasted_dims(x, gx): ...@@ -2045,7 +2044,14 @@ def _sum_grad_over_bcasted_dims(x, gx):
for i in range(x_dim_added): for i in range(x_dim_added):
assert gx.broadcastable[i] assert gx.broadcastable[i]
gx = gx.dimshuffle(*range(x_dim_added, gx.ndim)) gx = gx.dimshuffle(*range(x_dim_added, gx.ndim))
assert gx.broadcastable == x.broadcastable # Broadcastable flags of gx can be the same or more specific than x.
# Only unallowed case is x_dim_b == True and gx_dim_b == False.
assert not any(
x_dim_b and not gx_dim_b
for x_dim_b, gx_dim_b in zip(
x.type.broadcastable, gx.type.broadcastable, strict=True
)
), (x.type, gx.type)
return gx return gx
......
...@@ -12,7 +12,9 @@ import pytensor.tensor.basic as ptb ...@@ -12,7 +12,9 @@ import pytensor.tensor.basic as ptb
from pytensor import function from pytensor import function
from pytensor.compile import DeepCopyOp, shared from pytensor.compile import DeepCopyOp, shared
from pytensor.compile.io import In from pytensor.compile.io import In
from pytensor.compile.mode import Mode
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.gradient import grad
from pytensor.graph.op import get_test_value from pytensor.graph.op import get_test_value
from pytensor.graph.rewriting.utils import is_same_graph from pytensor.graph.rewriting.utils import is_same_graph
from pytensor.printing import pprint from pytensor.printing import pprint
...@@ -22,6 +24,7 @@ from pytensor.tensor.blockwise import Blockwise ...@@ -22,6 +24,7 @@ from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import exp, isinf from pytensor.tensor.math import exp, isinf
from pytensor.tensor.math import sum as pt_sum from pytensor.tensor.math import sum as pt_sum
from pytensor.tensor.shape import specify_shape
from pytensor.tensor.subtensor import ( from pytensor.tensor.subtensor import (
AdvancedIncSubtensor, AdvancedIncSubtensor,
AdvancedIncSubtensor1, AdvancedIncSubtensor1,
...@@ -1660,6 +1663,25 @@ class TestIncSubtensor: ...@@ -1660,6 +1663,25 @@ class TestIncSubtensor:
), ),
) )
def test_grad_broadcastable_specialization(self):
    """Regression test: taking the gradient must not fail when ``gx`` ends up
    with a more precise static shape than ``x`` after indexing.

    Reported in
    https://discourse.pymc.io/t/marginalized-mixture-wont-begin-sampling-throws-assertion-error/15969
    """
    x = vector("x")  # shape is unknown at write time; will be (2,) at runtime
    buffer = x.zeros_like()
    # Update a subtensor whose shape is unknown at write time (runtime (1,))
    buffer = buffer[1:].set(exp(x[1:]))
    buffer = specify_shape(buffer, 2)
    dx = grad(buffer.sum(), x)
    py_mode = Mode(linker="py", optimizer=None)
    result = dx.eval({x: [1, 1]}, mode=py_mode)
    np.testing.assert_allclose(result, [0, np.e])
class TestIncSubtensor1: class TestIncSubtensor1:
def setup_method(self): def setup_method(self):
......
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment