Unverified commit 9f80bdcd, authored by Ricardo Vieira, committed by GitHub

Fix bug in gradient of Blockwise'd Scan (#1482)

* Avoid pytest warning for variable name
* Respect core type shape in gradient of Blockwise
* Refactor Blockwise L_op
Parent: b218ffe3
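For context, the failure mode this commit addresses: vectorizing a Scan wraps it in a Blockwise, and the old `Blockwise.L_op` rebuilt the core graph with inputs of fully unknown static shape, which breaks core ops (such as Scan with a fixed number of steps) whose gradient depends on the core type's shape. A minimal reproducer, adapted from the regression test added in this commit:

```python
import numpy as np

from pytensor import scan
from pytensor.gradient import grad
from pytensor.graph.replace import vectorize_graph
from pytensor.tensor import tensor

n_steps = 3
seq = tensor("seq", shape=(n_steps, 1), dtype="float64")
# A trivial Scan whose core input type has a fully static shape (3, 1)
out, _ = scan(lambda s: s, sequences=[seq], n_steps=n_steps)

# Vectorizing the graph wraps the Scan in a Blockwise with one batch dimension
vec_seq = tensor("vec_seq", shape=(None, n_steps, 1), dtype="float64")
vec_out = vectorize_graph(out, replace={seq: vec_seq})

# Before this fix, the gradient rebuilt the core Scan with inputs of shape
# (None, None), violating its static core type; the core shape is now respected
grad_vec = grad(vec_out.sum(), vec_seq)
print(grad_vec.eval({vec_seq: np.ones((4, n_steps, 1))}))  # all ones
```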
pytensor/tensor/blockwise.py

@@ -344,81 +344,66 @@ class Blockwise(COp):
         return [[True for _ in node.outputs] for _ in node.inputs]
 
-    def _bgrad(self, inputs, outputs, ograds):
-        # Grad, with respect to broadcasted versions of inputs
-
-        def as_core(t, core_t):
-            # Inputs could be NullType or DisconnectedType
-            if isinstance(t.type, NullType | DisconnectedType):
-                return t
-            return core_t.type()
+    def L_op(self, inputs, outputs, output_gradients):
+        batch_ndim = self.batch_ndim(outputs[0].owner)
 
+        # Obtain core_op gradients
         with config.change_flags(compute_test_value="off"):
-            safe_inputs = [
-                tensor(dtype=inp.type.dtype, shape=(None,) * len(sig))
-                for inp, sig in zip(inputs, self.inputs_sig, strict=True)
-            ]
-            core_node = self._create_dummy_core_node(safe_inputs)
-
             core_inputs = [
-                as_core(inp, core_inp)
-                for inp, core_inp in zip(inputs, core_node.inputs, strict=True)
-            ]
-            core_ograds = [
-                as_core(ograd, core_ograd)
-                for ograd, core_ograd in zip(ograds, core_node.outputs, strict=True)
+                tensor(
+                    dtype=inp.type.dtype,
+                    shape=inp.type.shape[batch_ndim:],
+                )
+                for inp in inputs
             ]
-            # FIXME: These core_outputs do not depend on core_inputs, not pretty
-            # It's not neccessarily a problem because if they are referenced by the gradient,
-            # they get replaced later in vectorize. But if the Op was to make any decision
-            # by introspecting the dependencies of output on inputs it would fail badly!
-            core_outputs = core_node.outputs
+            core_outputs = self._create_dummy_core_node(core_inputs).outputs
 
-            core_igrads = self.core_op.L_op(core_inputs, core_outputs, core_ograds)
+            # Define core output_gradients, but keep original disconnected/null output_gradients (if any)
+            core_output_gradients = [
+                output_grad
+                if isinstance(output_grad.type, NullType | DisconnectedType)
+                else core_output.type()
+                for output_grad, core_output in zip(
+                    output_gradients, core_outputs, strict=True
+                )
+            ]
 
-        igrads = vectorize_graph(
-            [core_igrad for core_igrad in core_igrads if core_igrad is not None],
-            replace=dict(
-                zip(
-                    core_inputs + core_outputs + core_ograds,
-                    inputs + outputs + ograds,
-                    strict=True,
-                )
-            ),
-        )
+            core_input_gradients = self.core_op.L_op(
+                core_inputs, core_outputs, core_output_gradients
+            )
 
-        igrads_iter = iter(igrads)
-        return [
-            None if core_igrad is None else next(igrads_iter)
-            for core_igrad in core_igrads
-        ]
-
-    def L_op(self, inputs, outs, ograds):
-        from pytensor.tensor.math import sum as pt_sum
-
-        # Compute grad with respect to broadcasted input
-        rval = self._bgrad(inputs, outs, ograds)
+        # Vectorize core gradients to original inputs
+        input_gradients = list(
+            vectorize_graph(
+                core_input_gradients,
+                replace=dict(
+                    zip(
+                        core_inputs + core_outputs + core_output_gradients,
+                        inputs + outputs + output_gradients,
+                        strict=True,
+                    )
+                ),
+            )
+        )
 
-        # Sum out the broadcasted dimensions
-        batch_ndims = self.batch_ndim(outs[0].owner)
-        batch_shape = outs[0].type.shape[:batch_ndims]
+        # Sum out the broadcasted batch dimensions
+        batch_shape = outputs[0].type.shape[:batch_ndim]
         for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig, strict=True)):
-            if isinstance(rval[i].type, NullType | DisconnectedType):
+            if isinstance(input_gradients[i].type, NullType | DisconnectedType):
                 continue
 
-            assert inp.type.ndim == batch_ndims + len(sig)
+            assert inp.type.ndim == batch_ndim + len(sig)
 
-            to_sum = [
+            if to_sum := [
                 j
                 for j, (inp_s, out_s) in enumerate(
                     zip(inp.type.shape, batch_shape, strict=False)
                 )
                 if inp_s == 1 and out_s != 1
-            ]
-            if to_sum:
-                rval[i] = pt_sum(rval[i], axis=to_sum, keepdims=True)
+            ]:
+                input_gradients[i] = input_gradients[i].sum(axis=to_sum, keepdims=True)
 
-        return rval
+        return input_gradients
 
     def _create_node_gufunc(self, node: Apply, impl) -> Callable:
         """Define (or retrieve) the node gufunc used in `perform`.
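The crux of the refactor above: the dummy core inputs handed to `core_op.L_op` now keep the static core shape, `inp.type.shape[batch_ndim:]`, where the old `_bgrad` used `(None,) * len(sig)` and discarded it. A small sketch of the difference (the shapes here are illustrative, not taken from the PR):

```python
from pytensor.tensor import tensor

inp = tensor("inp", shape=(5, 3, 2))  # one batch dim (5) + core shape (3, 2)
batch_ndim = 1

# Old behavior: the dummy core input lost all static shape information
old_core = tensor(dtype=inp.type.dtype, shape=(None, None))
assert old_core.type.shape == (None, None)

# New behavior: the dummy core input keeps the static core shape
new_core = tensor(dtype=inp.type.dtype, shape=inp.type.shape[batch_ndim:])
assert new_core.type.shape == (3, 2)
```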
tests/tensor/test_blockwise.py

@@ -6,11 +6,11 @@ import pytest
 import scipy.linalg
 
 import pytensor
-from pytensor import In, config, function
+from pytensor import In, config, function, scan
 from pytensor.compile import get_default_mode, get_mode
 from pytensor.gradient import grad
 from pytensor.graph import Apply, Op
-from pytensor.graph.replace import vectorize_node
+from pytensor.graph.replace import vectorize_graph, vectorize_node
 from pytensor.raise_op import assert_op
 from pytensor.tensor import diagonal, dmatrix, log, ones_like, scalar, tensor, vector
 from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
@@ -162,13 +162,13 @@ class MyTestOp(Op):
         raise NotImplementedError("Test Op should not be present in final graph")
 
 
-test_op = MyTestOp()
+my_test_op = MyTestOp()
 
 
 def test_vectorize_node_default_signature():
     vec = tensor(shape=(None,))
     mat = tensor(shape=(5, None))
-    node = test_op.make_node(vec, mat)
+    node = my_test_op.make_node(vec, mat)
 
     vect_node = vectorize_node(node, mat, mat)
     assert isinstance(vect_node.op, Blockwise) and isinstance(
@@ -179,9 +179,9 @@ def test_vectorize_node_default_signature():
     with pytest.raises(
         ValueError, match="Signature not provided nor found in core_op MyTestOp"
     ):
-        Blockwise(test_op)
+        Blockwise(my_test_op)
 
-    vect_node = Blockwise(test_op, signature="(m),(n)->(m),(n)").make_node(vec, mat)
+    vect_node = Blockwise(my_test_op, signature="(m),(n)->(m),(n)").make_node(vec, mat)
     assert vect_node.outputs[0].type.shape == (
         5,
         None,
@@ -198,7 +198,7 @@ def test_blockwise_shape():
     inp_test = np.zeros((5, 4, 3), dtype=config.floatX)
 
     # Shape can be inferred from inputs
-    op = Blockwise(test_op, signature="(m, n) -> (n, m)")
+    op = Blockwise(my_test_op, signature="(m, n) -> (n, m)")
     out = op(inp)
     assert out.type.shape == (5, None, None)
@@ -210,7 +210,7 @@ def test_blockwise_shape():
     assert tuple(shape_fn(inp_test)) == (5, 3, 4)
 
     # Shape can only be partially inferred from inputs
-    op = Blockwise(test_op, signature="(m, n) -> (m, k)")
+    op = Blockwise(my_test_op, signature="(m, n) -> (m, k)")
     out = op(inp)
     assert out.type.shape == (5, None, None)
@@ -233,7 +233,7 @@ def test_blockwise_shape():
     inp1_test = np.zeros((7, 1, 4, 3), dtype=config.floatX)
     inp2_test = np.zeros((1, 5, 4, 3), dtype=config.floatX)
-    op = Blockwise(test_op, signature="(m, n), (m, n) -> (n, m), (m, k)")
+    op = Blockwise(my_test_op, signature="(m, n), (m, n) -> (n, m), (m, k)")
     outs = op(inp1, inp2)
     assert outs[0].type.shape == (7, 5, None, None)
     assert outs[1].type.shape == (7, 5, None, None)
@@ -650,3 +650,51 @@ def test_gradient_mixed_discrete_output_core_op():
         np.ones(12, dtype=config.floatX),
         strict=True,
     )
+
+
+def test_blockwise_grad_core_type():
+    class StrictCoreTypeOp(Op):
+        def make_node(self, x):
+            assert x.type.shape[-1] == 2
+            return Apply(self, [x], [x.type()])
+
+        def perform(self, node, inputs, output_storage):
+            output_storage[0][0] = inputs[0] + 1
+
+        def L_op(self, inputs, outputs, output_grads):
+            [x] = inputs
+            assert x.type.shape == (2,)
+            return [x.zeros_like()]
+
+    strict_core_type_op = StrictCoreTypeOp()
+    block_strict_core_type_op = Blockwise(strict_core_type_op, signature="(a)->(a)")
+
+    x = tensor("x", shape=(5, 2), dtype="float64")
+    y = block_strict_core_type_op(x)
+    assert y.type.shape == (5, 2)
+
+    grad_y = grad(y.sum(), x)
+    assert grad_y.type.shape == (5, 2)
+    np.testing.assert_allclose(
+        grad_y.eval({x: np.ones((5, 2))}),
+        np.zeros((5, 2)),
+    )
+
+
+def test_scan_gradient_core_type():
+    n_steps = 3
+    seq = tensor("seq", shape=(n_steps, 1), dtype="float64")
+    out, _ = scan(
+        lambda s: s,
+        sequences=[seq],
+        n_steps=n_steps,
+    )
+
+    vec_seq = tensor("vec_seq", shape=(None, n_steps, 1), dtype="float64")
+    vec_out = vectorize_graph(out, replace={seq: vec_seq})
+    grad_sit_sot0 = grad(vec_out.sum(), vec_seq)
+
+    np.testing.assert_allclose(
+        grad_sit_sot0.eval({vec_seq: np.ones((4, n_steps, 1))}),
+        np.ones((4, n_steps, 1)),
+    )