提交 ee884b87 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Jesse Grabowski

Fix Elemwise and Blockwise gradient for Ops with mixed discrete and continuous output types

上级 676296c6
......@@ -18,7 +18,7 @@ from pytensor.graph.replace import (
from pytensor.scalar import ScalarType
from pytensor.tensor import as_tensor_variable
from pytensor.tensor.shape import shape_padleft
from pytensor.tensor.type import TensorType, continuous_dtypes, discrete_dtypes, tensor
from pytensor.tensor.type import TensorType, tensor
from pytensor.tensor.utils import (
_parse_gufunc_signature,
broadcast_static_dim_lengths,
......@@ -256,6 +256,10 @@ class Blockwise(Op):
as_core(ograd, core_ograd)
for ograd, core_ograd in zip(ograds, core_node.outputs, strict=True)
]
# FIXME: These core_outputs do not depend on core_inputs, not pretty
# It's not necessarily a problem because if they are referenced by the gradient,
# they get replaced later in vectorize. But if the Op was to make any decision
# by introspecting the dependencies of output on inputs it would fail badly!
core_outputs = core_node.outputs
core_igrads = self.core_op.L_op(core_inputs, core_outputs, core_ograds)
......@@ -283,27 +287,6 @@ class Blockwise(Op):
# Compute grad with respect to broadcasted input
rval = self._bgrad(inputs, outs, ograds)
# TODO: (Borrowed from Elemwise) make sure that zeros are clearly identifiable
# to the gradient.grad method when the outputs have
# some integer and some floating point outputs
if any(out.type.dtype not in continuous_dtypes for out in outs):
# For integer output, return value may only be zero or undefined
# We don't bother with trying to check that the scalar ops
# correctly returned something that evaluates to 0, we just make
# the return value obviously zero so that gradient.grad can tell
# this op did the right thing.
new_rval = []
for elem, inp in zip(rval, inputs, strict=True):
if isinstance(elem.type, NullType | DisconnectedType):
new_rval.append(elem)
else:
elem = inp.zeros_like()
if str(elem.type.dtype) not in continuous_dtypes:
elem = elem.astype(config.floatX)
assert str(elem.type.dtype) not in discrete_dtypes
new_rval.append(elem)
return new_rval
# Sum out the broadcasted dimensions
batch_ndims = self.batch_ndim(outs[0].owner)
batch_shape = outs[0].type.shape[:batch_ndims]
......
......@@ -515,27 +515,6 @@ class Elemwise(OpenMPOp):
# Compute grad with respect to broadcasted input
rval = self._bgrad(inputs, outs, ograds)
# TODO: make sure that zeros are clearly identifiable
# to the gradient.grad method when the outputs have
# some integer and some floating point outputs
if any(out.type.dtype not in continuous_dtypes for out in outs):
# For integer output, return value may only be zero or undefined
# We don't bother with trying to check that the scalar ops
# correctly returned something that evaluates to 0, we just make
# the return value obviously zero so that gradient.grad can tell
# this op did the right thing.
new_rval = []
for elem, ipt in zip(rval, inputs, strict=True):
if isinstance(elem.type, NullType | DisconnectedType):
new_rval.append(elem)
else:
elem = ipt.zeros_like()
if str(elem.type.dtype) not in continuous_dtypes:
elem = elem.astype(config.floatX)
assert str(elem.type.dtype) not in discrete_dtypes
new_rval.append(elem)
return new_rval
# sum out the broadcasted dimensions
for i, ipt in enumerate(inputs):
if isinstance(rval[i].type, NullType | DisconnectedType):
......
......@@ -12,7 +12,7 @@ from pytensor.gradient import grad
from pytensor.graph import Apply, Op
from pytensor.graph.replace import vectorize_node
from pytensor.raise_op import assert_op
from pytensor.tensor import diagonal, log, tensor
from pytensor.tensor import diagonal, log, ones_like, scalar, tensor, vector
from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
from pytensor.tensor.nlinalg import MatrixInverse
from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot
......@@ -603,3 +603,26 @@ class TestInplace:
# Confirm input was destroyed
assert (A_val == A_val_copy).all() == (op.destroy_map.get(0, None) != [0])
assert (b_val == b_val_copy).all() == (op.destroy_map.get(0, None) != [1])
def test_gradient_mixed_discrete_output_core_op():
    """Blockwise.L_op handles a core Op with one float and one integer output."""

    class MixedDtypeCoreOp(Op):
        # Scalar core signature: one input, two outputs (float, int).
        gufunc_signature = "()->(),()"
        itypes = [scalar().type]
        otypes = [scalar().type, scalar(dtype=int).type]

        def perform(self, node, inputs, outputs):
            raise NotImplementedError()

        def L_op(self, inputs, outputs, output_gradients):
            # Gradient flows only through the continuous (float) output.
            return [ones_like(inputs[0]) * output_gradients[0]]

    blockwise_op = Blockwise(MixedDtypeCoreOp())
    x = vector("x")
    float_out, _ = blockwise_op(x)

    # NaN inputs guarantee the result comes from L_op alone, not from data.
    test_value = np.full(12, np.nan, dtype=config.floatX)
    expected = np.ones(12, dtype=config.floatX)
    np.testing.assert_array_equal(
        grad(float_out.sum(), x).eval({x: test_value}),
        expected,
        strict=True,
    )
......@@ -11,16 +11,16 @@ import pytensor
import pytensor.scalar as ps
import pytensor.tensor as pt
import tests.unittest_tools as utt
from pytensor import In, Out
from pytensor import In, Out, config, grad
from pytensor.compile.function import function
from pytensor.compile.mode import Mode
from pytensor.configdefaults import config
from pytensor.graph.basic import Apply, Variable
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.replace import vectorize_node
from pytensor.link.basic import PerformLinker
from pytensor.link.c.basic import CLinker, OpWiseCLinker
from pytensor.npy_2_compat import numpy_maxdims
from pytensor.scalar import ScalarOp, float32, float64, int32, int64
from pytensor.tensor import as_tensor_variable
from pytensor.tensor.basic import get_scalar_constant_value, second
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
......@@ -1068,3 +1068,28 @@ def test_c_careduce_benchmark(axis, c_contiguous, benchmark):
return careduce_benchmark_tester(
axis, c_contiguous, mode="FAST_RUN", benchmark=benchmark
)
def test_gradient_mixed_discrete_output_scalar_op():
    """Elemwise.L_op handles a scalar Op with one float and one integer output."""

    class MixedDtypeScalarOp(ScalarOp):
        def make_node(self, *inputs):
            # Match the integer output's bit-width to the configured float
            # precision. The original condition compared config.floatX against
            # "int64", which can never be true (floatX is a float dtype string),
            # so int64 was unreachable; key off "float64" as the float_op line does.
            float_op = float64 if config.floatX == "float64" else float32
            int_op = int64 if config.floatX == "float64" else int32
            inputs = [float_op()]
            outputs = [float_op(), int_op()]
            return Apply(self, inputs, outputs)

        def perform(self, node, inputs, outputs):
            raise NotImplementedError()

        def L_op(self, inputs, outputs, output_gradients):
            # Gradient flows only through the continuous (float) output.
            return [inputs[0].ones_like() * output_gradients[0]]

    op = Elemwise(MixedDtypeScalarOp())
    x = vector("x")
    y, _ = op(x)
    # NaN inputs guarantee the result comes from L_op alone, not from data
    # (perform is unimplemented, so nothing is actually evaluated through it).
    np.testing.assert_array_equal(
        grad(y.sum(), x).eval({x: np.full((12,), np.nan, dtype=config.floatX)}),
        np.ones((12,), dtype=config.floatX),
        strict=True,
    )
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论