Commit 2faa56a4, authored by ricardoV94, committed by Ricardo Vieira

Do not redefine DisconnectedType every time

Parent d8b51df8
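Why: `DisconnectedType()()` instantiates a fresh `DisconnectedType` at every call site and then calls that instance to build a variable. This commit switches call sites to a single shared instance, `disconnected_type`, imported from `pytensor.gradient`. A minimal sketch of the pattern, assuming `pytensor/gradient.py` defines the singleton roughly like this (that file is not shown in this diff):

    # Sketch only (assumed, not shown in this diff): pytensor/gradient.py
    # defines DisconnectedType as a graph Type; calling a Type instance
    # builds a Variable of that Type.
    class DisconnectedType:
        def __call__(self, name=None):
            # The real implementation returns a Variable whose .type is `self`.
            return (type(self).__name__, name)

    # Before this commit: a throwaway Type instance at every call site.
    g_old = DisconnectedType()()

    # After: one module-level singleton, which is what the call sites
    # below import as `disconnected_type`.
    disconnected_type = DisconnectedType()
    g_new = disconnected_type()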
@@ -495,7 +495,7 @@ If both outputs are disconnected PyTensor will not bother calling the :meth:`L_op`
     from pytensor.graph.op import Op
     from pytensor.graph.basic import Apply
-    from pytensor.gradient import DisconnectedType
+    from pytensor.gradient import DisconnectedType, disconnected_type

     class TransposeAndSumOp(Op):
         __props__ = ()
@@ -539,13 +539,13 @@ If both outputs are disconnected PyTensor will not bother calling the :meth:`L_op`
             out1_grad, out2_grad = output_grads
             if isinstance(out1_grad.type, DisconnectedType):
-                x_grad = DisconnectedType()()
+                x_grad = disconnected_type()
             else:
                 # Transpose the last two dimensions of the output gradient
                 x_grad = pt.swapaxes(out1_grad, -1, -2)

             if isinstance(out2_grad.type, DisconnectedType):
-                y_grad = DisconnectedType()()
+                y_grad = disconnected_type()
             else:
                 # Broadcast the output gradient to the same shape as y
                 y_grad = pt.broadcast_to(pt.expand_dims(out2_grad, -1), y.shape)
...
 import numpy as np

-from pytensor.gradient import DisconnectedType
+from pytensor.gradient import disconnected_type
 from pytensor.graph.basic import Apply, Variable
 from pytensor.graph.op import Op
 from pytensor.tensor.basic import as_tensor_variable
@@ -142,7 +142,7 @@ class PdbBreakpoint(Op):
             output_storage[i][0] = inputs[i + 1]

     def grad(self, inputs, output_gradients):
-        return [DisconnectedType()(), *output_gradients]
+        return [disconnected_type(), *output_gradients]

     def infer_shape(self, fgraph, inputs, input_shapes):
         # Return the shape of every input but the condition (first input)
...
@@ -2,7 +2,7 @@
 from textwrap import indent

-from pytensor.gradient import DisconnectedType
+from pytensor.gradient import disconnected_type
 from pytensor.graph.basic import Apply, Constant, Variable
 from pytensor.graph.replace import _vectorize_node
 from pytensor.link.c.op import COp
@@ -89,7 +89,10 @@ class CheckAndRaise(COp):
             raise self.exc_type(self.msg)

     def grad(self, input, output_gradients):
-        return output_gradients + [DisconnectedType()()] * (len(input) - 1)
+        return [
+            *output_gradients,
+            *(disconnected_type() for _ in range(len(input) - 1)),
+        ]

     def connection_pattern(self, node):
         return [[1]] + [[0]] * (len(node.inputs) - 1)
...
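The CheckAndRaise hunk also shows the invariant these grad methods maintain: `connection_pattern` and `grad` must agree, with a 0 entry in the pattern corresponding to a disconnected placeholder in the gradient list. A hedged usage sketch with `Assert`, an instance of `CheckAndRaise` (illustration only, not part of the diff):

    # The first input is passed through; later inputs are boolean
    # conditions checked at runtime.
    import pytensor.tensor as pt
    from pytensor import grad
    from pytensor.raise_op import Assert

    x = pt.vector("x")
    n = pt.iscalar("n")
    checked = Assert("n must be positive")(x, n > 0)

    # The gradient flows only to x; n is disconnected, matching the
    # [[1], [0]] connection pattern.
    gx = grad(checked.sum(), x)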
@@ -22,7 +22,7 @@ import numpy as np
 import pytensor
 from pytensor import printing
 from pytensor.configdefaults import config
-from pytensor.gradient import DisconnectedType, grad_undefined
+from pytensor.gradient import disconnected_type, grad_undefined
 from pytensor.graph.basic import Apply, Constant, Variable, clone
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import HasInnerGraph
@@ -2426,13 +2426,13 @@ class Second(BinaryScalarOp):
         (gz,) = gout
         if y.type in continuous_types:
             # x is disconnected because the elements of x are not used
-            return DisconnectedType()(), gz
+            return disconnected_type(), gz
         else:
             # when y is discrete, we assume the function can be extended
             # to deal with real-valued inputs by rounding them to the
             # nearest integer. f(x+eps) thus equals f(x) so the gradient
             # is zero, not disconnected or undefined
-            return DisconnectedType()(), y.zeros_like(dtype=config.floatX)
+            return disconnected_type(), y.zeros_like(dtype=config.floatX)

 second = Second(name="second")
...
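The comments in the Second hunk above carry the real reasoning: `second(x, y)` evaluates to `y`, so `x` never influences the output and its gradient is disconnected, while a discrete `y` gets an explicit zero gradient. A short illustration of the continuous case (not part of the diff; uses only public PyTensor API):

    import pytensor.tensor as pt
    from pytensor import grad

    x = pt.scalar("x")
    y = pt.scalar("y")
    out = pt.second(x, y)  # value is y, broadcast against the shape of x

    gy = grad(out, y)  # the output gradient gz passes straight through
    # d(out)/dx is disconnected; grad() raises unless told to tolerate it:
    gx = grad(out, x, disconnected_inputs="ignore")  # yields a zero gradient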
@@ -63,7 +63,14 @@ from pytensor.compile.io import In, Out
 from pytensor.compile.mode import Mode, get_mode
 from pytensor.compile.profiling import register_profiler_printer
 from pytensor.configdefaults import config
-from pytensor.gradient import DisconnectedType, NullType, Rop, grad, grad_undefined
+from pytensor.gradient import (
+    DisconnectedType,
+    NullType,
+    Rop,
+    disconnected_type,
+    grad,
+    grad_undefined,
+)
 from pytensor.graph.basic import (
     Apply,
     Variable,
@@ -3073,7 +3080,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
         )
         outputs = local_op(*outer_inputs, return_list=True)

         # Re-order the gradients correctly
-        gradients = [DisconnectedType()()]
+        gradients = [disconnected_type()]  # n_steps is disconnected

         offset = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot + n_sitsot_outs
         for p, (x, t) in enumerate(
@@ -3098,7 +3105,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
                 else:
                     gradients.append(x[::-1])
             elif t == "disconnected":
-                gradients.append(DisconnectedType()())
+                gradients.append(disconnected_type())
             elif t == "through_untraced":
                 gradients.append(
                     grad_undefined(
@@ -3126,7 +3133,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
                 else:
                     gradients.append(x[::-1])
             elif t == "disconnected":
-                gradients.append(DisconnectedType()())
+                gradients.append(disconnected_type())
             elif t == "through_untraced":
                 gradients.append(
                     grad_undefined(
@@ -3149,7 +3156,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
                 if not isinstance(dC_dout.type, DisconnectedType) and connected:
                     disconnected = False
             if disconnected:
-                gradients.append(DisconnectedType()())
+                gradients.append(disconnected_type())
             else:
                 gradients.append(
                     grad_undefined(
@@ -3157,7 +3164,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
                     )
                 )

-        gradients += [DisconnectedType()() for _ in range(info.n_nit_sot)]
+        gradients.extend(disconnected_type() for _ in range(info.n_nit_sot))

         begin = end
         end = begin + n_sitsot_outs
@@ -3167,7 +3174,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
             if t == "connected":
                 gradients.append(x[-1])
             elif t == "disconnected":
-                gradients.append(DisconnectedType()())
+                gradients.append(disconnected_type())
             elif t == "through_untraced":
                 gradients.append(
                     grad_undefined(
@@ -3195,7 +3202,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
             ):
                 disconnected = False
             if disconnected:
-                gradients[idx] = DisconnectedType()()
+                gradients[idx] = disconnected_type()

         return gradients

     def R_op(self, inputs, eval_points):
...
@@ -18,7 +18,7 @@ import pytensor
 from pytensor import _as_symbolic, as_symbolic
 from pytensor import scalar as ps
 from pytensor.configdefaults import config
-from pytensor.gradient import DisconnectedType, grad_undefined
+from pytensor.gradient import DisconnectedType, disconnected_type, grad_undefined
 from pytensor.graph.basic import Apply, Constant, Variable
 from pytensor.graph.op import Op
 from pytensor.link.c.type import generic
@@ -480,9 +480,9 @@ class CSM(Op):
         )
         return [
             g_data,
-            DisconnectedType()(),
-            DisconnectedType()(),
-            DisconnectedType()(),
+            disconnected_type(),
+            disconnected_type(),
+            disconnected_type(),
         ]

     def infer_shape(self, fgraph, node, shapes):
@@ -1940,7 +1940,7 @@ class ConstructSparseFromList(Op):
         gx = g_output
         gy = pytensor.tensor.subtensor.advanced_subtensor1(g_output, *idx_list)
-        return [gx, gy] + [DisconnectedType()()] * len(idx_list)
+        return [gx, gy, *(disconnected_type() for _ in range(len(idx_list)))]

 construct_sparse_from_list = ConstructSparseFromList()
...
@@ -22,7 +22,7 @@ import pytensor.scalar.sharedvar
 from pytensor import config, printing
 from pytensor import scalar as ps
 from pytensor.compile.builders import OpFromGraph
-from pytensor.gradient import DisconnectedType, grad_undefined
+from pytensor.gradient import DisconnectedType, disconnected_type, grad_undefined
 from pytensor.graph import RewriteDatabaseQuery
 from pytensor.graph.basic import Apply, Constant, Variable, equal_computations
 from pytensor.graph.fg import FunctionGraph, Output
@@ -1738,7 +1738,7 @@ class Alloc(COp):
         # the inputs that specify the shape. If you grow the
         # shape by epsilon, the existing elements do not
         # change.
-        return [gx] + [DisconnectedType()() for i in inputs[1:]]
+        return [gx, *(disconnected_type() for _ in range(len(inputs) - 1))]

     def R_op(self, inputs, eval_points):
         if eval_points[0] is None:
@@ -2277,7 +2277,7 @@ class Split(COp):
         return [
             join(axis, *new_g_outputs),
             grad_undefined(self, 1, axis),
-            DisconnectedType()(),
+            disconnected_type(),
         ]

     def R_op(self, inputs, eval_points):
@@ -3340,14 +3340,14 @@ class ARange(COp):
         if self.dtype in discrete_dtypes:
             return [
                 start.zeros_like(dtype=config.floatX),
-                DisconnectedType()(),
+                disconnected_type(),
                 step.zeros_like(dtype=config.floatX),
             ]
         else:
             num_steps_taken = outputs[0].shape[0]
             return [
                 gz.sum(),
-                DisconnectedType()(),
+                disconnected_type(),
                 (gz * arange(num_steps_taken, dtype=self.dtype)).sum(),
             ]
@@ -4374,7 +4374,7 @@ class AllocEmpty(COp):
         return [[False] for i in node.inputs]

     def grad(self, inputs, grads):
-        return [DisconnectedType()() for i in inputs]
+        return [disconnected_type() for _ in range(len(inputs))]

     def R_op(self, inputs, eval_points):
         return [zeros(inputs, self.dtype)]
...
@@ -8,7 +8,6 @@ from numpy.lib.array_utils import normalize_axis_index
 import pytensor
 import pytensor.scalar.basic as ps
 from pytensor.gradient import (
-    DisconnectedType,
     _float_zeros_like,
     disconnected_type,
     grad_undefined,
@@ -716,7 +715,7 @@ class Repeat(Op):
         gx_transpose = ptb.zeros_like(x_transpose)[repeated_arange].inc(gz_transpose)
         gx = ptb.moveaxis(gx_transpose, 0, axis)

-        return [gx, DisconnectedType()()]
+        return [gx, disconnected_type()]

     def infer_shape(self, fgraph, node, ins_shapes):
         i0_shapes = ins_shapes[0]
...
 import numpy as np

-from pytensor.gradient import DisconnectedType
+from pytensor.gradient import disconnected_type
 from pytensor.graph.basic import Apply
 from pytensor.graph.op import Op
 from pytensor.tensor.basic import as_tensor_variable
@@ -59,7 +59,7 @@ class RFFTOp(Op):
             + [slice(None)]
         )
         gout = set_subtensor(gout[idx], gout[idx] * 0.5)
-        return [irfft_op(gout, s), DisconnectedType()()]
+        return [irfft_op(gout, s), disconnected_type()]

     def connection_pattern(self, node):
         # Specify that shape input parameter has no connection to graph and gradients.
@@ -121,7 +121,7 @@ class IRFFTOp(Op):
             + [slice(None)]
         )
         gf = set_subtensor(gf[idx], gf[idx] * 2)
-        return [gf, DisconnectedType()()]
+        return [gf, disconnected_type()]

     def connection_pattern(self, node):
         # Specify that shape input parameter has no connection to graph and gradients.
...
@@ -8,7 +8,7 @@ from numpy.lib.array_utils import normalize_axis_tuple
 from pytensor import scalar as ps
 from pytensor.compile.builders import OpFromGraph
-from pytensor.gradient import DisconnectedType
+from pytensor.gradient import DisconnectedType, disconnected_type
 from pytensor.graph.basic import Apply
 from pytensor.graph.op import Op
 from pytensor.tensor import TensorLike
@@ -652,8 +652,8 @@ class SVD(Op):
         ]
         if all(is_disconnected):
             # This should never actually be reached by Pytensor -- the SVD Op should be pruned from the gradient
-            # graph if its fully disconnected. It is included for completeness.
-            return [DisconnectedType()()]  # pragma: no cover
+            # graph if it's fully disconnected. It is included for completeness.
+            return [disconnected_type()]  # pragma: no cover
         elif is_disconnected == [True, False, True]:
             # This is the same as the compute_uv = False, so we can drop back to that simpler computation, without
...
@@ -6,7 +6,7 @@ import numpy as np
 from numpy.lib._array_utils_impl import normalize_axis_index, normalize_axis_tuple

 from pytensor import Variable
-from pytensor.gradient import DisconnectedType
+from pytensor.gradient import disconnected_type
 from pytensor.graph import Apply
 from pytensor.graph.op import Op
 from pytensor.graph.replace import _vectorize_node
@@ -217,7 +217,7 @@ class SplitDims(Op):
         n_axes = g_out.ndim - x.ndim + 1
         axis_range = list(range(self.axis, self.axis + n_axes))

-        return [join_dims(g_out, axis=axis_range), DisconnectedType()()]
+        return [join_dims(g_out, axis=axis_range), disconnected_type()]

 @_vectorize_node.register(SplitDims)
...
@@ -10,7 +10,7 @@ import numpy as np
 from numpy.lib.array_utils import normalize_axis_tuple

 import pytensor
-from pytensor.gradient import DisconnectedType
+from pytensor.gradient import disconnected_type
 from pytensor.graph import Op
 from pytensor.graph.basic import Apply, Variable
 from pytensor.graph.replace import _vectorize_node
@@ -103,7 +103,7 @@ class Shape(COp):
         # the elements of the tensor variable do not participate
         # in the computation of the shape, so they are not really
         # part of the graph
-        return [pytensor.gradient.DisconnectedType()()]
+        return [disconnected_type()]

     def R_op(self, inputs, eval_points):
         return [None]
@@ -474,8 +474,9 @@ class SpecifyShape(COp):
     def grad(self, inp, grads):
         _x, *shape = inp
         (gz,) = grads
-        return [specify_shape(gz, shape)] + [
-            pytensor.gradient.DisconnectedType()() for _ in range(len(shape))
-        ]
+        return [
+            specify_shape(gz, shape),
+            *(disconnected_type() for _ in range(len(shape))),
+        ]

     def R_op(self, inputs, eval_points):
@@ -725,7 +726,7 @@ class Reshape(COp):
     def grad(self, inp, grads):
         x, _shp = inp
         (g_out,) = grads
-        return [reshape(g_out, shape(x), ndim=x.ndim), DisconnectedType()()]
+        return [reshape(g_out, shape(x), ndim=x.ndim), disconnected_type()]

     def R_op(self, inputs, eval_points):
         if eval_points[0] is None:
...
@@ -5,7 +5,7 @@ import numpy as np
 from numpy import convolve as numpy_convolve
 from scipy.signal import convolve as scipy_convolve

-from pytensor.gradient import DisconnectedType
+from pytensor.gradient import disconnected_type
 from pytensor.graph import Apply, Constant
 from pytensor.graph.op import Op
 from pytensor.link.c.op import COp
@@ -109,7 +109,7 @@ class AbstractConvolveNd:
         return [
             self(grad, flip(in2), full_mode_in1_bar),
             self(grad, flip(in1), full_mode_in2_bar),
-            DisconnectedType()(),
+            disconnected_type(),
         ]
...
@@ -11,7 +11,7 @@ from scipy.linalg import get_lapack_funcs
 import pytensor
 from pytensor import ifelse
 from pytensor import tensor as pt
-from pytensor.gradient import DisconnectedType
+from pytensor.gradient import DisconnectedType, disconnected_type
 from pytensor.graph.basic import Apply
 from pytensor.graph.op import Op
 from pytensor.raise_op import Assert, CheckAndRaise
@@ -1966,7 +1966,7 @@ class QR(Op):
         ]
         if all(is_disconnected):
             # This should never be reached by Pytensor
-            return [DisconnectedType()()]  # pragma: no cover
+            return [disconnected_type()]  # pragma: no cover

         for disconnected, output_grad, output in zip(
             is_disconnected, output_grads, [Q, R], strict=True
...
@@ -11,7 +11,7 @@ from numpy.lib.array_utils import normalize_axis_tuple
 import pytensor
 from pytensor import scalar as ps
 from pytensor.configdefaults import config
-from pytensor.gradient import DisconnectedType
+from pytensor.gradient import disconnected_type
 from pytensor.graph.basic import Apply, Constant, Variable
 from pytensor.graph.op import Op
 from pytensor.graph.replace import _vectorize_node
@@ -988,7 +988,7 @@ class Subtensor(COp):
         # set subtensor here at:
         # pytensor/tensor/opt.py:local_incsubtensor_of_zeros_to_setsubtensor()
         first = IncSubtensor(self.idx_list)(x.zeros_like(), gz, *rest)
-        return [first] + [DisconnectedType()()] * len(rest)
+        return [first, *(disconnected_type() for _ in range(len(rest)))]

     def connection_pattern(self, node):
         rval = [[True], *([False] for _ in node.inputs[1:])]
@@ -2023,7 +2023,7 @@ class IncSubtensor(COp):
         gy = Subtensor(idx_list=self.idx_list)(g_output, *idx_list)
         gy = _sum_grad_over_bcasted_dims(y, gy)
-        return [gx, gy] + [DisconnectedType()()] * len(idx_list)
+        return [gx, gy, *(disconnected_type() for _ in range(len(idx_list)))]

 class IncSubtensorPrinter(SubtensorPrinter):
@@ -2135,7 +2135,7 @@ class AdvancedSubtensor1(COp):
                     " from a tensor with ndim != 2. ndim is " + str(x.type.ndim)
                 )
-            rval1 = [pytensor.sparse.construct_sparse_from_list(x, gz, ilist)]
+            rval1 = pytensor.sparse.construct_sparse_from_list(x, gz, ilist)
         else:
             if x.dtype in discrete_dtypes:
                 # The output dtype is the same as x
@@ -2144,8 +2144,8 @@ class AdvancedSubtensor1(COp):
                 raise NotImplementedError("No support for complex grad yet")
             else:
                 gx = x.zeros_like()
-            rval1 = [advanced_inc_subtensor1(gx, gz, ilist)]
-        return rval1 + [DisconnectedType()()] * (len(inputs) - 1)
+            rval1 = advanced_inc_subtensor1(gx, gz, ilist)
+        return [rval1, *(disconnected_type() for _ in range(len(inputs) - 1))]

     def R_op(self, inputs, eval_points):
         if eval_points[0] is None:
@@ -2519,7 +2519,7 @@ class AdvancedIncSubtensor1(COp):
         gy = advanced_subtensor1(g_output, idx_list)
         gy = _sum_grad_over_bcasted_dims(y, gy)
-        return [gx, gy, DisconnectedType()()]
+        return [gx, gy, disconnected_type()]

 advanced_inc_subtensor1 = AdvancedIncSubtensor1()
@@ -2771,9 +2771,10 @@ class AdvancedSubtensor(Op):
         else:
             gx = x.zeros_like()
         rest = inputs[1:]
-        return [advanced_inc_subtensor(gx, gz, *rest)] + [DisconnectedType()()] * len(
-            rest
-        )
+        return [
+            advanced_inc_subtensor(gx, gz, *rest),
+            *(disconnected_type() for _ in range(len(rest))),
+        ]

     @staticmethod
     def non_contiguous_adv_indexing(node: Apply) -> bool:
@@ -2933,7 +2934,7 @@ class AdvancedIncSubtensor(Op):
         # Make sure to sum gy over the dimensions of y that have been
         # added or broadcasted
         gy = _sum_grad_over_bcasted_dims(y, gy)
-        return [gx, gy] + [DisconnectedType()() for _ in idxs]
+        return [gx, gy, *(disconnected_type() for _ in range(len(idxs)))]

     @staticmethod
     def non_contiguous_adv_indexing(node: Apply) -> bool:
...
@@ -6,7 +6,7 @@ import numpy as np
 import pytensor
 from pytensor import _as_symbolic
-from pytensor.gradient import DisconnectedType
+from pytensor.gradient import disconnected_type
 from pytensor.graph.basic import Apply, Constant, Variable
 from pytensor.graph.op import Op
 from pytensor.link.c.type import Generic, Type
@@ -44,7 +44,7 @@ class MakeSlice(Op):
         out[0] = slice(*inp)

     def grad(self, inputs, grads):
-        return [DisconnectedType()() for i in inputs]
+        return [disconnected_type() for _ in range(len(inputs))]

 make_slice = MakeSlice()
...