Commit 2faa56a4, authored by ricardoV94, committed by Ricardo Vieira

Do not redefine DisconnectedType every time

Parent commit: d8b51df8
...@@ -495,7 +495,7 @@ If both outputs are disconnected PyTensor will not bother calling the :meth:`L_o ...@@ -495,7 +495,7 @@ If both outputs are disconnected PyTensor will not bother calling the :meth:`L_o
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.graph.basic import Apply from pytensor.graph.basic import Apply
from pytensor.gradient import DisconnectedType from pytensor.gradient import DisconnectedType, disconnected_type
class TransposeAndSumOp(Op): class TransposeAndSumOp(Op):
__props__ = () __props__ = ()
...@@ -539,13 +539,13 @@ If both outputs are disconnected PyTensor will not bother calling the :meth:`L_o ...@@ -539,13 +539,13 @@ If both outputs are disconnected PyTensor will not bother calling the :meth:`L_o
out1_grad, out2_grad = output_grads out1_grad, out2_grad = output_grads
if isinstance(out1_grad.type, DisconnectedType): if isinstance(out1_grad.type, DisconnectedType):
x_grad = DisconnectedType()() x_grad = disconnected_type()
else: else:
# Transpose the last two dimensions of the output gradient # Transpose the last two dimensions of the output gradient
x_grad = pt.swapaxes(out1_grad, -1, -2) x_grad = pt.swapaxes(out1_grad, -1, -2)
if isinstance(out2_grad.type, DisconnectedType): if isinstance(out2_grad.type, DisconnectedType):
y_grad = DisconnectedType()() y_grad = disconnected_type()
else: else:
# Broadcast the output gradient to the same shape as y # Broadcast the output gradient to the same shape as y
y_grad = pt.broadcast_to(pt.expand_dims(out2_grad, -1), y.shape) y_grad = pt.broadcast_to(pt.expand_dims(out2_grad, -1), y.shape)
......
import numpy as np import numpy as np
from pytensor.gradient import DisconnectedType from pytensor.gradient import disconnected_type
from pytensor.graph.basic import Apply, Variable from pytensor.graph.basic import Apply, Variable
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.tensor.basic import as_tensor_variable from pytensor.tensor.basic import as_tensor_variable
...@@ -142,7 +142,7 @@ class PdbBreakpoint(Op): ...@@ -142,7 +142,7 @@ class PdbBreakpoint(Op):
output_storage[i][0] = inputs[i + 1] output_storage[i][0] = inputs[i + 1]
def grad(self, inputs, output_gradients): def grad(self, inputs, output_gradients):
return [DisconnectedType()(), *output_gradients] return [disconnected_type(), *output_gradients]
def infer_shape(self, fgraph, inputs, input_shapes): def infer_shape(self, fgraph, inputs, input_shapes):
# Return the shape of every input but the condition (first input) # Return the shape of every input but the condition (first input)
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
from textwrap import indent from textwrap import indent
from pytensor.gradient import DisconnectedType from pytensor.gradient import disconnected_type
from pytensor.graph.basic import Apply, Constant, Variable from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.graph.replace import _vectorize_node from pytensor.graph.replace import _vectorize_node
from pytensor.link.c.op import COp from pytensor.link.c.op import COp
...@@ -89,7 +89,10 @@ class CheckAndRaise(COp): ...@@ -89,7 +89,10 @@ class CheckAndRaise(COp):
raise self.exc_type(self.msg) raise self.exc_type(self.msg)
def grad(self, input, output_gradients): def grad(self, input, output_gradients):
return output_gradients + [DisconnectedType()()] * (len(input) - 1) return [
*output_gradients,
*(disconnected_type() for _ in range(len(input) - 1)),
]
def connection_pattern(self, node): def connection_pattern(self, node):
return [[1]] + [[0]] * (len(node.inputs) - 1) return [[1]] + [[0]] * (len(node.inputs) - 1)
......
...@@ -22,7 +22,7 @@ import numpy as np ...@@ -22,7 +22,7 @@ import numpy as np
import pytensor import pytensor
from pytensor import printing from pytensor import printing
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.gradient import DisconnectedType, grad_undefined from pytensor.gradient import disconnected_type, grad_undefined
from pytensor.graph.basic import Apply, Constant, Variable, clone from pytensor.graph.basic import Apply, Constant, Variable, clone
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import HasInnerGraph from pytensor.graph.op import HasInnerGraph
...@@ -2426,13 +2426,13 @@ class Second(BinaryScalarOp): ...@@ -2426,13 +2426,13 @@ class Second(BinaryScalarOp):
(gz,) = gout (gz,) = gout
if y.type in continuous_types: if y.type in continuous_types:
# x is disconnected because the elements of x are not used # x is disconnected because the elements of x are not used
return DisconnectedType()(), gz return disconnected_type(), gz
else: else:
# when y is discrete, we assume the function can be extended # when y is discrete, we assume the function can be extended
# to deal with real-valued inputs by rounding them to the # to deal with real-valued inputs by rounding them to the
# nearest integer. f(x+eps) thus equals f(x) so the gradient # nearest integer. f(x+eps) thus equals f(x) so the gradient
# is zero, not disconnected or undefined # is zero, not disconnected or undefined
return DisconnectedType()(), y.zeros_like(dtype=config.floatX) return disconnected_type(), y.zeros_like(dtype=config.floatX)
second = Second(name="second") second = Second(name="second")
......
...@@ -63,7 +63,14 @@ from pytensor.compile.io import In, Out ...@@ -63,7 +63,14 @@ from pytensor.compile.io import In, Out
from pytensor.compile.mode import Mode, get_mode from pytensor.compile.mode import Mode, get_mode
from pytensor.compile.profiling import register_profiler_printer from pytensor.compile.profiling import register_profiler_printer
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.gradient import DisconnectedType, NullType, Rop, grad, grad_undefined from pytensor.gradient import (
DisconnectedType,
NullType,
Rop,
disconnected_type,
grad,
grad_undefined,
)
from pytensor.graph.basic import ( from pytensor.graph.basic import (
Apply, Apply,
Variable, Variable,
...@@ -3073,7 +3080,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -3073,7 +3080,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
) )
outputs = local_op(*outer_inputs, return_list=True) outputs = local_op(*outer_inputs, return_list=True)
# Re-order the gradients correctly # Re-order the gradients correctly
gradients = [DisconnectedType()()] gradients = [disconnected_type()] # n_steps is disconnected
offset = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot + n_sitsot_outs offset = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot + n_sitsot_outs
for p, (x, t) in enumerate( for p, (x, t) in enumerate(
...@@ -3098,7 +3105,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -3098,7 +3105,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
else: else:
gradients.append(x[::-1]) gradients.append(x[::-1])
elif t == "disconnected": elif t == "disconnected":
gradients.append(DisconnectedType()()) gradients.append(disconnected_type())
elif t == "through_untraced": elif t == "through_untraced":
gradients.append( gradients.append(
grad_undefined( grad_undefined(
...@@ -3126,7 +3133,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -3126,7 +3133,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
else: else:
gradients.append(x[::-1]) gradients.append(x[::-1])
elif t == "disconnected": elif t == "disconnected":
gradients.append(DisconnectedType()()) gradients.append(disconnected_type())
elif t == "through_untraced": elif t == "through_untraced":
gradients.append( gradients.append(
grad_undefined( grad_undefined(
...@@ -3149,7 +3156,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -3149,7 +3156,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
if not isinstance(dC_dout.type, DisconnectedType) and connected: if not isinstance(dC_dout.type, DisconnectedType) and connected:
disconnected = False disconnected = False
if disconnected: if disconnected:
gradients.append(DisconnectedType()()) gradients.append(disconnected_type())
else: else:
gradients.append( gradients.append(
grad_undefined( grad_undefined(
...@@ -3157,7 +3164,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -3157,7 +3164,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
) )
) )
gradients += [DisconnectedType()() for _ in range(info.n_nit_sot)] gradients.extend(disconnected_type() for _ in range(info.n_nit_sot))
begin = end begin = end
end = begin + n_sitsot_outs end = begin + n_sitsot_outs
...@@ -3167,7 +3174,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -3167,7 +3174,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
if t == "connected": if t == "connected":
gradients.append(x[-1]) gradients.append(x[-1])
elif t == "disconnected": elif t == "disconnected":
gradients.append(DisconnectedType()()) gradients.append(disconnected_type())
elif t == "through_untraced": elif t == "through_untraced":
gradients.append( gradients.append(
grad_undefined( grad_undefined(
...@@ -3195,7 +3202,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -3195,7 +3202,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
): ):
disconnected = False disconnected = False
if disconnected: if disconnected:
gradients[idx] = DisconnectedType()() gradients[idx] = disconnected_type()
return gradients return gradients
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
......
...@@ -18,7 +18,7 @@ import pytensor ...@@ -18,7 +18,7 @@ import pytensor
from pytensor import _as_symbolic, as_symbolic from pytensor import _as_symbolic, as_symbolic
from pytensor import scalar as ps from pytensor import scalar as ps
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.gradient import DisconnectedType, grad_undefined from pytensor.gradient import DisconnectedType, disconnected_type, grad_undefined
from pytensor.graph.basic import Apply, Constant, Variable from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.link.c.type import generic from pytensor.link.c.type import generic
...@@ -480,9 +480,9 @@ class CSM(Op): ...@@ -480,9 +480,9 @@ class CSM(Op):
) )
return [ return [
g_data, g_data,
DisconnectedType()(), disconnected_type(),
DisconnectedType()(), disconnected_type(),
DisconnectedType()(), disconnected_type(),
] ]
def infer_shape(self, fgraph, node, shapes): def infer_shape(self, fgraph, node, shapes):
...@@ -1940,7 +1940,7 @@ class ConstructSparseFromList(Op): ...@@ -1940,7 +1940,7 @@ class ConstructSparseFromList(Op):
gx = g_output gx = g_output
gy = pytensor.tensor.subtensor.advanced_subtensor1(g_output, *idx_list) gy = pytensor.tensor.subtensor.advanced_subtensor1(g_output, *idx_list)
return [gx, gy] + [DisconnectedType()()] * len(idx_list) return [gx, gy, *(disconnected_type() for _ in range(len(idx_list)))]
construct_sparse_from_list = ConstructSparseFromList() construct_sparse_from_list = ConstructSparseFromList()
...@@ -22,7 +22,7 @@ import pytensor.scalar.sharedvar ...@@ -22,7 +22,7 @@ import pytensor.scalar.sharedvar
from pytensor import config, printing from pytensor import config, printing
from pytensor import scalar as ps from pytensor import scalar as ps
from pytensor.compile.builders import OpFromGraph from pytensor.compile.builders import OpFromGraph
from pytensor.gradient import DisconnectedType, grad_undefined from pytensor.gradient import DisconnectedType, disconnected_type, grad_undefined
from pytensor.graph import RewriteDatabaseQuery from pytensor.graph import RewriteDatabaseQuery
from pytensor.graph.basic import Apply, Constant, Variable, equal_computations from pytensor.graph.basic import Apply, Constant, Variable, equal_computations
from pytensor.graph.fg import FunctionGraph, Output from pytensor.graph.fg import FunctionGraph, Output
...@@ -1738,7 +1738,7 @@ class Alloc(COp): ...@@ -1738,7 +1738,7 @@ class Alloc(COp):
# the inputs that specify the shape. If you grow the # the inputs that specify the shape. If you grow the
# shape by epsilon, the existing elements do not # shape by epsilon, the existing elements do not
# change. # change.
return [gx] + [DisconnectedType()() for i in inputs[1:]] return [gx, *(disconnected_type() for _ in range(len(inputs) - 1))]
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
if eval_points[0] is None: if eval_points[0] is None:
...@@ -2277,7 +2277,7 @@ class Split(COp): ...@@ -2277,7 +2277,7 @@ class Split(COp):
return [ return [
join(axis, *new_g_outputs), join(axis, *new_g_outputs),
grad_undefined(self, 1, axis), grad_undefined(self, 1, axis),
DisconnectedType()(), disconnected_type(),
] ]
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
...@@ -3340,14 +3340,14 @@ class ARange(COp): ...@@ -3340,14 +3340,14 @@ class ARange(COp):
if self.dtype in discrete_dtypes: if self.dtype in discrete_dtypes:
return [ return [
start.zeros_like(dtype=config.floatX), start.zeros_like(dtype=config.floatX),
DisconnectedType()(), disconnected_type(),
step.zeros_like(dtype=config.floatX), step.zeros_like(dtype=config.floatX),
] ]
else: else:
num_steps_taken = outputs[0].shape[0] num_steps_taken = outputs[0].shape[0]
return [ return [
gz.sum(), gz.sum(),
DisconnectedType()(), disconnected_type(),
(gz * arange(num_steps_taken, dtype=self.dtype)).sum(), (gz * arange(num_steps_taken, dtype=self.dtype)).sum(),
] ]
...@@ -4374,7 +4374,7 @@ class AllocEmpty(COp): ...@@ -4374,7 +4374,7 @@ class AllocEmpty(COp):
return [[False] for i in node.inputs] return [[False] for i in node.inputs]
def grad(self, inputs, grads): def grad(self, inputs, grads):
return [DisconnectedType()() for i in inputs] return [disconnected_type() for _ in range(len(inputs))]
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
return [zeros(inputs, self.dtype)] return [zeros(inputs, self.dtype)]
......
...@@ -8,7 +8,6 @@ from numpy.lib.array_utils import normalize_axis_index ...@@ -8,7 +8,6 @@ from numpy.lib.array_utils import normalize_axis_index
import pytensor import pytensor
import pytensor.scalar.basic as ps import pytensor.scalar.basic as ps
from pytensor.gradient import ( from pytensor.gradient import (
DisconnectedType,
_float_zeros_like, _float_zeros_like,
disconnected_type, disconnected_type,
grad_undefined, grad_undefined,
...@@ -716,7 +715,7 @@ class Repeat(Op): ...@@ -716,7 +715,7 @@ class Repeat(Op):
gx_transpose = ptb.zeros_like(x_transpose)[repeated_arange].inc(gz_transpose) gx_transpose = ptb.zeros_like(x_transpose)[repeated_arange].inc(gz_transpose)
gx = ptb.moveaxis(gx_transpose, 0, axis) gx = ptb.moveaxis(gx_transpose, 0, axis)
return [gx, DisconnectedType()()] return [gx, disconnected_type()]
def infer_shape(self, fgraph, node, ins_shapes): def infer_shape(self, fgraph, node, ins_shapes):
i0_shapes = ins_shapes[0] i0_shapes = ins_shapes[0]
......
import numpy as np import numpy as np
from pytensor.gradient import DisconnectedType from pytensor.gradient import disconnected_type
from pytensor.graph.basic import Apply from pytensor.graph.basic import Apply
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.tensor.basic import as_tensor_variable from pytensor.tensor.basic import as_tensor_variable
...@@ -59,7 +59,7 @@ class RFFTOp(Op): ...@@ -59,7 +59,7 @@ class RFFTOp(Op):
+ [slice(None)] + [slice(None)]
) )
gout = set_subtensor(gout[idx], gout[idx] * 0.5) gout = set_subtensor(gout[idx], gout[idx] * 0.5)
return [irfft_op(gout, s), DisconnectedType()()] return [irfft_op(gout, s), disconnected_type()]
def connection_pattern(self, node): def connection_pattern(self, node):
# Specify that shape input parameter has no connection to graph and gradients. # Specify that shape input parameter has no connection to graph and gradients.
...@@ -121,7 +121,7 @@ class IRFFTOp(Op): ...@@ -121,7 +121,7 @@ class IRFFTOp(Op):
+ [slice(None)] + [slice(None)]
) )
gf = set_subtensor(gf[idx], gf[idx] * 2) gf = set_subtensor(gf[idx], gf[idx] * 2)
return [gf, DisconnectedType()()] return [gf, disconnected_type()]
def connection_pattern(self, node): def connection_pattern(self, node):
# Specify that shape input parameter has no connection to graph and gradients. # Specify that shape input parameter has no connection to graph and gradients.
......
...@@ -8,7 +8,7 @@ from numpy.lib.array_utils import normalize_axis_tuple ...@@ -8,7 +8,7 @@ from numpy.lib.array_utils import normalize_axis_tuple
from pytensor import scalar as ps from pytensor import scalar as ps
from pytensor.compile.builders import OpFromGraph from pytensor.compile.builders import OpFromGraph
from pytensor.gradient import DisconnectedType from pytensor.gradient import DisconnectedType, disconnected_type
from pytensor.graph.basic import Apply from pytensor.graph.basic import Apply
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.tensor import TensorLike from pytensor.tensor import TensorLike
...@@ -652,8 +652,8 @@ class SVD(Op): ...@@ -652,8 +652,8 @@ class SVD(Op):
] ]
if all(is_disconnected): if all(is_disconnected):
# This should never actually be reached by Pytensor -- the SVD Op should be pruned from the gradient # This should never actually be reached by Pytensor -- the SVD Op should be pruned from the gradient
# graph if its fully disconnected. It is included for completeness. # graph if it's fully disconnected. It is included for completeness.
return [DisconnectedType()()] # pragma: no cover return [disconnected_type()] # pragma: no cover
elif is_disconnected == [True, False, True]: elif is_disconnected == [True, False, True]:
# This is the same as the compute_uv = False, so we can drop back to that simpler computation, without # This is the same as the compute_uv = False, so we can drop back to that simpler computation, without
......
...@@ -6,7 +6,7 @@ import numpy as np ...@@ -6,7 +6,7 @@ import numpy as np
from numpy.lib._array_utils_impl import normalize_axis_index, normalize_axis_tuple from numpy.lib._array_utils_impl import normalize_axis_index, normalize_axis_tuple
from pytensor import Variable from pytensor import Variable
from pytensor.gradient import DisconnectedType from pytensor.gradient import disconnected_type
from pytensor.graph import Apply from pytensor.graph import Apply
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.graph.replace import _vectorize_node from pytensor.graph.replace import _vectorize_node
...@@ -217,7 +217,7 @@ class SplitDims(Op): ...@@ -217,7 +217,7 @@ class SplitDims(Op):
n_axes = g_out.ndim - x.ndim + 1 n_axes = g_out.ndim - x.ndim + 1
axis_range = list(range(self.axis, self.axis + n_axes)) axis_range = list(range(self.axis, self.axis + n_axes))
return [join_dims(g_out, axis=axis_range), DisconnectedType()()] return [join_dims(g_out, axis=axis_range), disconnected_type()]
@_vectorize_node.register(SplitDims) @_vectorize_node.register(SplitDims)
......
...@@ -10,7 +10,7 @@ import numpy as np ...@@ -10,7 +10,7 @@ import numpy as np
from numpy.lib.array_utils import normalize_axis_tuple from numpy.lib.array_utils import normalize_axis_tuple
import pytensor import pytensor
from pytensor.gradient import DisconnectedType from pytensor.gradient import disconnected_type
from pytensor.graph import Op from pytensor.graph import Op
from pytensor.graph.basic import Apply, Variable from pytensor.graph.basic import Apply, Variable
from pytensor.graph.replace import _vectorize_node from pytensor.graph.replace import _vectorize_node
...@@ -103,7 +103,7 @@ class Shape(COp): ...@@ -103,7 +103,7 @@ class Shape(COp):
# the elements of the tensor variable do not participate # the elements of the tensor variable do not participate
# in the computation of the shape, so they are not really # in the computation of the shape, so they are not really
# part of the graph # part of the graph
return [pytensor.gradient.DisconnectedType()()] return [disconnected_type()]
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
return [None] return [None]
...@@ -474,8 +474,9 @@ class SpecifyShape(COp): ...@@ -474,8 +474,9 @@ class SpecifyShape(COp):
def grad(self, inp, grads): def grad(self, inp, grads):
_x, *shape = inp _x, *shape = inp
(gz,) = grads (gz,) = grads
return [specify_shape(gz, shape)] + [ return [
pytensor.gradient.DisconnectedType()() for _ in range(len(shape)) specify_shape(gz, shape),
*(disconnected_type() for _ in range(len(shape))),
] ]
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
...@@ -725,7 +726,7 @@ class Reshape(COp): ...@@ -725,7 +726,7 @@ class Reshape(COp):
def grad(self, inp, grads): def grad(self, inp, grads):
x, _shp = inp x, _shp = inp
(g_out,) = grads (g_out,) = grads
return [reshape(g_out, shape(x), ndim=x.ndim), DisconnectedType()()] return [reshape(g_out, shape(x), ndim=x.ndim), disconnected_type()]
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
if eval_points[0] is None: if eval_points[0] is None:
......
...@@ -5,7 +5,7 @@ import numpy as np ...@@ -5,7 +5,7 @@ import numpy as np
from numpy import convolve as numpy_convolve from numpy import convolve as numpy_convolve
from scipy.signal import convolve as scipy_convolve from scipy.signal import convolve as scipy_convolve
from pytensor.gradient import DisconnectedType from pytensor.gradient import disconnected_type
from pytensor.graph import Apply, Constant from pytensor.graph import Apply, Constant
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.link.c.op import COp from pytensor.link.c.op import COp
...@@ -109,7 +109,7 @@ class AbstractConvolveNd: ...@@ -109,7 +109,7 @@ class AbstractConvolveNd:
return [ return [
self(grad, flip(in2), full_mode_in1_bar), self(grad, flip(in2), full_mode_in1_bar),
self(grad, flip(in1), full_mode_in2_bar), self(grad, flip(in1), full_mode_in2_bar),
DisconnectedType()(), disconnected_type(),
] ]
......
...@@ -11,7 +11,7 @@ from scipy.linalg import get_lapack_funcs ...@@ -11,7 +11,7 @@ from scipy.linalg import get_lapack_funcs
import pytensor import pytensor
from pytensor import ifelse from pytensor import ifelse
from pytensor import tensor as pt from pytensor import tensor as pt
from pytensor.gradient import DisconnectedType from pytensor.gradient import DisconnectedType, disconnected_type
from pytensor.graph.basic import Apply from pytensor.graph.basic import Apply
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.raise_op import Assert, CheckAndRaise from pytensor.raise_op import Assert, CheckAndRaise
...@@ -1966,7 +1966,7 @@ class QR(Op): ...@@ -1966,7 +1966,7 @@ class QR(Op):
] ]
if all(is_disconnected): if all(is_disconnected):
# This should never be reached by Pytensor # This should never be reached by Pytensor
return [DisconnectedType()()] # pragma: no cover return [disconnected_type()] # pragma: no cover
for disconnected, output_grad, output in zip( for disconnected, output_grad, output in zip(
is_disconnected, output_grads, [Q, R], strict=True is_disconnected, output_grads, [Q, R], strict=True
......
...@@ -11,7 +11,7 @@ from numpy.lib.array_utils import normalize_axis_tuple ...@@ -11,7 +11,7 @@ from numpy.lib.array_utils import normalize_axis_tuple
import pytensor import pytensor
from pytensor import scalar as ps from pytensor import scalar as ps
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.gradient import DisconnectedType from pytensor.gradient import disconnected_type
from pytensor.graph.basic import Apply, Constant, Variable from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.graph.replace import _vectorize_node from pytensor.graph.replace import _vectorize_node
...@@ -988,7 +988,7 @@ class Subtensor(COp): ...@@ -988,7 +988,7 @@ class Subtensor(COp):
# set subtensor here at: # set subtensor here at:
# pytensor/tensor/opt.py:local_incsubtensor_of_zeros_to_setsubtensor() # pytensor/tensor/opt.py:local_incsubtensor_of_zeros_to_setsubtensor()
first = IncSubtensor(self.idx_list)(x.zeros_like(), gz, *rest) first = IncSubtensor(self.idx_list)(x.zeros_like(), gz, *rest)
return [first] + [DisconnectedType()()] * len(rest) return [first, *(disconnected_type() for _ in range(len(rest)))]
def connection_pattern(self, node): def connection_pattern(self, node):
rval = [[True], *([False] for _ in node.inputs[1:])] rval = [[True], *([False] for _ in node.inputs[1:])]
...@@ -2023,7 +2023,7 @@ class IncSubtensor(COp): ...@@ -2023,7 +2023,7 @@ class IncSubtensor(COp):
gy = Subtensor(idx_list=self.idx_list)(g_output, *idx_list) gy = Subtensor(idx_list=self.idx_list)(g_output, *idx_list)
gy = _sum_grad_over_bcasted_dims(y, gy) gy = _sum_grad_over_bcasted_dims(y, gy)
return [gx, gy] + [DisconnectedType()()] * len(idx_list) return [gx, gy, *(disconnected_type() for _ in range(len(idx_list)))]
class IncSubtensorPrinter(SubtensorPrinter): class IncSubtensorPrinter(SubtensorPrinter):
...@@ -2135,7 +2135,7 @@ class AdvancedSubtensor1(COp): ...@@ -2135,7 +2135,7 @@ class AdvancedSubtensor1(COp):
" from a tensor with ndim != 2. ndim is " + str(x.type.ndim) " from a tensor with ndim != 2. ndim is " + str(x.type.ndim)
) )
rval1 = [pytensor.sparse.construct_sparse_from_list(x, gz, ilist)] rval1 = pytensor.sparse.construct_sparse_from_list(x, gz, ilist)
else: else:
if x.dtype in discrete_dtypes: if x.dtype in discrete_dtypes:
# The output dtype is the same as x # The output dtype is the same as x
...@@ -2144,8 +2144,8 @@ class AdvancedSubtensor1(COp): ...@@ -2144,8 +2144,8 @@ class AdvancedSubtensor1(COp):
raise NotImplementedError("No support for complex grad yet") raise NotImplementedError("No support for complex grad yet")
else: else:
gx = x.zeros_like() gx = x.zeros_like()
rval1 = [advanced_inc_subtensor1(gx, gz, ilist)] rval1 = advanced_inc_subtensor1(gx, gz, ilist)
return rval1 + [DisconnectedType()()] * (len(inputs) - 1) return [rval1, *(disconnected_type() for _ in range(len(inputs) - 1))]
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
if eval_points[0] is None: if eval_points[0] is None:
...@@ -2519,7 +2519,7 @@ class AdvancedIncSubtensor1(COp): ...@@ -2519,7 +2519,7 @@ class AdvancedIncSubtensor1(COp):
gy = advanced_subtensor1(g_output, idx_list) gy = advanced_subtensor1(g_output, idx_list)
gy = _sum_grad_over_bcasted_dims(y, gy) gy = _sum_grad_over_bcasted_dims(y, gy)
return [gx, gy, DisconnectedType()()] return [gx, gy, disconnected_type()]
advanced_inc_subtensor1 = AdvancedIncSubtensor1() advanced_inc_subtensor1 = AdvancedIncSubtensor1()
...@@ -2771,9 +2771,10 @@ class AdvancedSubtensor(Op): ...@@ -2771,9 +2771,10 @@ class AdvancedSubtensor(Op):
else: else:
gx = x.zeros_like() gx = x.zeros_like()
rest = inputs[1:] rest = inputs[1:]
return [advanced_inc_subtensor(gx, gz, *rest)] + [DisconnectedType()()] * len( return [
rest advanced_inc_subtensor(gx, gz, *rest),
) *(disconnected_type() for _ in range(len(rest))),
]
@staticmethod @staticmethod
def non_contiguous_adv_indexing(node: Apply) -> bool: def non_contiguous_adv_indexing(node: Apply) -> bool:
...@@ -2933,7 +2934,7 @@ class AdvancedIncSubtensor(Op): ...@@ -2933,7 +2934,7 @@ class AdvancedIncSubtensor(Op):
# Make sure to sum gy over the dimensions of y that have been # Make sure to sum gy over the dimensions of y that have been
# added or broadcasted # added or broadcasted
gy = _sum_grad_over_bcasted_dims(y, gy) gy = _sum_grad_over_bcasted_dims(y, gy)
return [gx, gy] + [DisconnectedType()() for _ in idxs] return [gx, gy, *(disconnected_type() for _ in range(len(idxs)))]
@staticmethod @staticmethod
def non_contiguous_adv_indexing(node: Apply) -> bool: def non_contiguous_adv_indexing(node: Apply) -> bool:
......
...@@ -6,7 +6,7 @@ import numpy as np ...@@ -6,7 +6,7 @@ import numpy as np
import pytensor import pytensor
from pytensor import _as_symbolic from pytensor import _as_symbolic
from pytensor.gradient import DisconnectedType from pytensor.gradient import disconnected_type
from pytensor.graph.basic import Apply, Constant, Variable from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.link.c.type import Generic, Type from pytensor.link.c.type import Generic, Type
...@@ -44,7 +44,7 @@ class MakeSlice(Op): ...@@ -44,7 +44,7 @@ class MakeSlice(Op):
out[0] = slice(*inp) out[0] = slice(*inp)
def grad(self, inputs, grads): def grad(self, inputs, grads):
return [DisconnectedType()() for i in inputs] return [disconnected_type() for _ in range(len(inputs))]
make_slice = MakeSlice() make_slice = MakeSlice()
......
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Sign in or register to post a comment