提交 d8b51df8 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Split: Return disconnected gradient for split sizes

上级 ccf53d19
...@@ -2254,18 +2254,19 @@ class Split(COp): ...@@ -2254,18 +2254,19 @@ class Split(COp):
out_shapes.append(temp) out_shapes.append(temp)
return out_shapes return out_shapes
def connection_pattern(self, node):
    """Return one row of per-output connection flags for each of the three inputs.

    The first two rows are all ``True`` and the third row is all ``False``,
    each row having one entry per output of ``node``. Going by the way
    ``L_op`` unpacks its inputs (``x, axis, n``), the third row corresponds
    to the split sizes, which are thereby reported as disconnected from
    every output.
    """
    num_outputs = len(node.outputs)
    # One flag per input, in input order; each row is a fresh list.
    input_is_connected = (True, True, False)
    return [[flag] * num_outputs for flag in input_is_connected]
def L_op(self, inputs, outputs, g_outputs): def L_op(self, inputs, outputs, g_outputs):
"""Join the gradients along the axis that was used to split x.""" """Join the gradients along the axis that was used to split x."""
_x, axis, n = inputs _x, axis, _n = inputs
# If all the output gradients are disconnected, then so are the inputs # We have to convert disconnected outputs to zeros before joining them
if builtins.all(isinstance(g.type, DisconnectedType) for g in g_outputs):
return [
DisconnectedType()(),
grad_undefined(self, 1, axis),
grad_undefined(self, 2, n),
]
# Else, we have to make them zeros before joining them
new_g_outputs = [] new_g_outputs = []
for o, g in zip(outputs, g_outputs, strict=True): for o, g in zip(outputs, g_outputs, strict=True):
if isinstance(g.type, DisconnectedType): if isinstance(g.type, DisconnectedType):
...@@ -2276,7 +2277,7 @@ class Split(COp): ...@@ -2276,7 +2277,7 @@ class Split(COp):
return [ return [
join(axis, *new_g_outputs), join(axis, *new_g_outputs),
grad_undefined(self, 1, axis), grad_undefined(self, 1, axis),
grad_undefined(self, 2, n), DisconnectedType()(),
] ]
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
......
...@@ -6,6 +6,7 @@ import tests.unittest_tools as utt ...@@ -6,6 +6,7 @@ import tests.unittest_tools as utt
from pytensor import config, function from pytensor import config, function
from pytensor import tensor as pt from pytensor import tensor as pt
from pytensor.graph import rewrite_graph, vectorize_graph from pytensor.graph import rewrite_graph, vectorize_graph
from pytensor.graph.op import io_connection_pattern
from pytensor.tensor.reshape import ( from pytensor.tensor.reshape import (
_analyze_axes_list, _analyze_axes_list,
join_dims, join_dims,
...@@ -289,3 +290,12 @@ class TestPack: ...@@ -289,3 +290,12 @@ class TestPack:
for input_val, output_val in zip(input_dict.values(), output_vals, strict=True): for input_val, output_val in zip(input_dict.values(), output_vals, strict=True):
np.testing.assert_allclose(input_val, output_val) np.testing.assert_allclose(input_val, output_val)
def test_unpack_connection():
    """Check that only the packed vector is connected to the unpacked outputs.

    The two integer scalars passed as ``packed_shapes`` must show up as
    disconnected in the input/output connection pattern.
    """
    packed = pt.vector("x")
    size0 = pt.scalar("d0", dtype=int)
    size1 = pt.scalar("d1", dtype=int)
    first, second = pt.unpack(packed, axes=None, packed_shapes=[size0, size1])
    total = first.sum() + second.sum()

    # One row per input (packed, size0, size1), one column per output.
    expected_pattern = [[True], [False], [False]]
    observed_pattern = io_connection_pattern([packed, size0, size1], [total])
    assert observed_pattern == expected_pattern
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论