提交 2eb8fca2 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Refactor code that handles disconnected L_op/R_op outputs in OpFromGraph

上级 be799d8f
"""Define new Ops from existing Ops""" """Define new Ops from existing Ops"""
import warnings import warnings
from collections import OrderedDict
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
from copy import copy from copy import copy
from functools import partial from functools import partial
from typing import Union, cast from typing import Union, cast
import pytensor.tensor as pt
from pytensor.compile.function import function from pytensor.compile.function import function
from pytensor.compile.function.pfunc import rebuild_collect_shared from pytensor.compile.function.pfunc import rebuild_collect_shared
from pytensor.compile.mode import optdb from pytensor.compile.mode import optdb
...@@ -251,57 +249,6 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -251,57 +249,6 @@ class OpFromGraph(Op, HasInnerGraph):
""" """
# Error-message templates used when validating user-supplied
# L_op/grad/R_op overrides; %-formatted at the raise site.
TYPE_ERR_MSG = (
    "L_op/gradient override should be (single or list of)"
    "None | OpFromGraph | callable | Variable "
    "with NullType or DisconnectedType, got %s"
)
STYPE_ERR_MSG = (
    "Overriding Variable instance can only have type"
    " of DisconnectedType or NullType, got %s"
)
LOP_TYPE_ERR_MSG = 'L_op type can only be "grad" or "lop", got %s.'
OV_INP_LEN_ERR_MSG = "expect overrider with %d inputs, got %d"
@staticmethod
def _filter_grad_var(grad, inp):
    """Split a gradient term into a graph-safe variable and an overrider.

    Parameters
    ----------
    grad
        A gradient Variable as returned by a ``grad()`` call; may carry
        a ``NullType`` or ``DisconnectedType`` type, which cannot be used
        directly inside an `OpFromGraph`.
    inp
        The input Variable this gradient corresponds to.

    Returns
    -------
    (filtered_var, overrider_var)
        ``filtered_var`` is safe to place in the inner graph;
        ``overrider_var`` is the special-typed gradient to substitute back
        at the outer ``grad()`` call, or ``None`` when no substitution is
        needed.
    """
    # Valid gradient types pass through untouched.
    if not isinstance(grad.type, NullType | DisconnectedType):
        return grad, None
    # Null/Disconnected gradients are replaced by a zero placeholder so
    # the inner OfG graph stays well-typed; the original special-typed
    # variable is remembered for later substitution.
    if hasattr(inp, "zeros_like"):
        return inp.zeros_like(), grad
    return pt.constant(0.0), grad
@staticmethod
def _filter_rop_var(inpJ, out):
    """R_op counterpart of ``_filter_grad_var``.

    Returns ``(filtered_var, overrider_var)`` where ``filtered_var`` can
    be used inside the inner graph and ``overrider_var`` (if not ``None``)
    must be substituted back at the outer ``R_op`` call.
    """
    jac_type = inpJ.type
    if isinstance(jac_type, NullType):
        return out.zeros_like(), inpJ
    if isinstance(jac_type, DisconnectedType):
        # R_op does not support DisconnectedType yet, so disconnected
        # Jacobian terms simply become zeros with no overrider.
        return out.zeros_like(), None
    return inpJ, None
def __init__( def __init__(
self, self,
inputs: list[Variable], inputs: list[Variable],
...@@ -322,8 +269,10 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -322,8 +269,10 @@ class OpFromGraph(Op, HasInnerGraph):
---------- ----------
inputs inputs
The inputs to the graph. The inputs to the graph.
outputs outputs
The outputs to the graph. The outputs to the graph.
inline inline
Defaults to ``False`` Defaults to ``False``
...@@ -332,6 +281,7 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -332,6 +281,7 @@ class OpFromGraph(Op, HasInnerGraph):
graph but rather its internal graph. graph but rather its internal graph.
``False`` : will use a pre-compiled function inside. ``False`` : will use a pre-compiled function inside.
grad_overrides grad_overrides
Defaults to ``None``. Defaults to ``None``.
This argument is mutually exclusive with ``lop_overrides``. This argument is mutually exclusive with ``lop_overrides``.
...@@ -345,6 +295,7 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -345,6 +295,7 @@ class OpFromGraph(Op, HasInnerGraph):
`callable`: Should take two args: ``inputs`` and ``output_grads``. `callable`: Should take two args: ``inputs`` and ``output_grads``.
Each argument is expected to be a list of :class:`Variable `. Each argument is expected to be a list of :class:`Variable `.
Must return list of :class:`Variable `. Must return list of :class:`Variable `.
lop_overrides lop_overrides
Defaults to ``None``. Defaults to ``None``.
...@@ -364,10 +315,6 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -364,10 +315,6 @@ class OpFromGraph(Op, HasInnerGraph):
Each argument is expected to be a list of :class:`Variable`. Each argument is expected to be a list of :class:`Variable`.
Must return list of :class:`Variable`. Must return list of :class:`Variable`.
`NullType` instance: Treat as non-differentiable
`DisconnectedType` instance: Treat as disconnected gradient,
numerically gives zero
``list``: Each `OpFromGraph`/callable must return a single ``list``: Each `OpFromGraph`/callable must return a single
:class:`Variable`. Each list element corresponds to gradient of :class:`Variable`. Each list element corresponds to gradient of
a specific input, length of list must be equal to number of inputs. a specific input, length of list must be equal to number of inputs.
...@@ -387,10 +334,6 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -387,10 +334,6 @@ class OpFromGraph(Op, HasInnerGraph):
Each argument is expected to be a list of :class:`Variable`. Must Each argument is expected to be a list of :class:`Variable`. Must
return list of :class:`Variable`. return list of :class:`Variable`.
`NullType` instance: Treat as non-differentiable `DisconnectedType`
instance: Treat as zero since `DisconnectedType` is not yet supported
in :meth:`Op.R_op`.
``list``: ``list``:
Each :class:`OpFromGraph`/callable must return a single Each :class:`OpFromGraph`/callable must return a single
:class:`Variable <pytensor.graph.basic.Variable>`. Each list element :class:`Variable <pytensor.graph.basic.Variable>`. Each list element
...@@ -398,12 +341,15 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -398,12 +341,15 @@ class OpFromGraph(Op, HasInnerGraph):
must be equal to number of outputs. connection_pattern If not must be equal to number of outputs. connection_pattern If not
``None``, this will be used as the connection_pattern for this ``None``, this will be used as the connection_pattern for this
:class:`Op`. :class:`Op`.
strict: bool, default False strict: bool, default False
If true, it raises when any variables needed to compute the inner graph If true, it raises when any variables needed to compute the inner graph
are not provided as explicit inputs. This can only happen for graphs with are not provided as explicit inputs. This can only happen for graphs with
shared variables. shared variables.
name name
A name for debugging purposes. A name for debugging purposes.
kwargs kwargs
Check :func:`pytensor.function` for more arguments, only works when not Check :func:`pytensor.function` for more arguments, only works when not
inline. inline.
...@@ -460,26 +406,19 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -460,26 +406,19 @@ class OpFromGraph(Op, HasInnerGraph):
self.grad_overrides = grad_overrides self.grad_overrides = grad_overrides
self.rop_overrides = rop_overrides self.rop_overrides = rop_overrides
if lop_overrides is not None: self._lop_op_interface = True
if grad_overrides is not None: if grad_overrides is not None:
if lop_overrides is not None:
raise ValueError( raise ValueError(
"lop_overrides and grad_overrides are mutually exclusive" "lop_overrides and grad_overrides are mutually exclusive"
) )
else:
self.set_lop_overrides(lop_overrides)
self._lop_type = "lop"
elif grad_overrides is not None:
warnings.warn( warnings.warn(
"grad_overrides is deprecated in favor of lop_overrides. Using it will lead to an error in the future.", "grad_overrides is deprecated in favor of lop_overrides. Using it will lead to an error in the future.",
FutureWarning, FutureWarning,
) )
self.set_lop_overrides(grad_overrides) self._lop_op_interface = False
self._lop_type = "grad" self._lop_op_cache: Callable | None = None
else: self._rop_op_cache: Callable | None = None
self.set_lop_overrides(None)
self._lop_type = "lop"
self.set_rop_overrides(rop_overrides)
self._connection_pattern = connection_pattern self._connection_pattern = connection_pattern
...@@ -501,307 +440,224 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -501,307 +440,224 @@ class OpFromGraph(Op, HasInnerGraph):
is_inline = self.is_inline is_inline = self.is_inline
return "{name}{{inline={is_inline}}}".format(**locals()) return "{name}{{inline={is_inline}}}".format(**locals())
def _combine_list_overrides(self, default_outs, custom_outs, callable_args):
"""Combines default and custom overrides into a single list of outputs."""
default_out_iter = iter(default_outs)
combined_outs = []
for custom_out in custom_outs:
if custom_out is None:
combined_outs.append(next(default_out_iter))
elif isinstance(custom_out, Variable):
if not isinstance(custom_out.type, NullType | DisconnectedType):
raise ValueError(
f"Override list can only contain NullType or DisconnectedType Variable instances, got {custom_out.type}"
)
combined_outs.append(custom_out)
elif callable(custom_out):
combined_outs.append(custom_out(*callable_args))
else:
raise ValueError(
f"Override list should contain None, Variable or callable, got {type(custom_out)}"
)
return combined_outs
def _call_custom_override(self, op_overrides, callable_args, nout):
"""Calls custom override function and provides informative error messages."""
if not callable(op_overrides):
raise TypeError(
f"L_op/R_op override should be None, a list or a Callable, got {type(op_overrides)}"
)
outputs = op_overrides(*callable_args)
if not isinstance(outputs, list):
raise TypeError(
f"Lop/Rop overriding function should return a list, got {type(outputs)}"
)
if len(outputs) != nout:
raise ValueError(
f"Lop/Rop overriding function {self.rop_overrides} should return "
f"a list of {nout} outputs, got {len(outputs)}"
)
return outputs
@config.change_flags(compute_test_value="off") @config.change_flags(compute_test_value="off")
def _recompute_lop_op(self): def _build_and_cache_lop_op(self) -> Callable:
""" """converts lop_overrides (or grad_overrides) from user supplied form to type(self) instance.
converts self._lop_op from user supplied form to type(self) instance
Results are cached in self._lop_op_cache
""" """
local_inputs = self.inner_inputs if self._lop_op_cache is not None:
local_outputs = self.inner_outputs return self._lop_op_cache
inp_len = len(local_inputs)
lop_op = self._lop_op inner_inputs = self.inner_inputs
inner_outputs = self.inner_outputs
if isinstance(lop_op, OpFromGraph): nin = len(inner_inputs)
if self._lop_op_is_cached: lop_overrides = (
return self.lop_overrides if self._lop_op_interface else self.grad_overrides
assert self._lop_type in ("lop", "grad"), ( )
self.LOP_TYPE_ERR_MSG % self._lop_type
) if isinstance(lop_overrides, OpFromGraph):
if self._lop_type == "grad": if self._lop_op_interface:
needed_ninps = inp_len + len(local_outputs) self._lop_op_cache = lop_overrides
ninps = len(lop_op.inner_inputs) lop_overrides.kwargs["on_unused_input"] = "ignore"
if needed_ninps != ninps: return lop_overrides
raise ValueError(self.OV_INP_LEN_ERR_MSG % (needed_ninps, ninps))
# make a wrapper callable else:
# We need to add a wrapper for the different input signature
def lop_op(inps, grads): # TODO: Remove this once the grad interface is gone
return self._lop_op(*(inps + grads)) def lop_overrides(inps, grads):
return self.grad_overrides(*inps, *grads)
elif self._lop_type == "lop":
# OfG can be directly used in L_op format
needed_ninps = inp_len + 2 * len(local_outputs)
ninps = len(lop_op.inner_inputs)
if needed_ninps != ninps:
raise ValueError(self.OV_INP_LEN_ERR_MSG % (needed_ninps, ninps))
self._lop_op_is_cached = True
self._lop_op_stypes_l = [None] * inp_len
self._lop_op.kwargs["on_unused_input"] = "ignore"
return
output_grads = [out_t() for out_t in self.output_types] output_grads = [out_t() for out_t in self.output_types]
fn_grad = partial( fn_grad = partial(
grad, grad,
cost=None, cost=None,
disconnected_inputs="ignore", disconnected_inputs="ignore",
return_disconnected="Disconnected", return_disconnected="disconnected",
null_gradients="return", null_gradients="return",
known_grads=OrderedDict(zip(local_outputs, output_grads)), known_grads=dict(zip(inner_outputs, output_grads)),
) )
assert self._lop_type in ("lop", "grad"), self.LOP_TYPE_ERR_MSG % self._lop_type if self._lop_op_interface:
if self._lop_type == "lop": callable_args = (inner_inputs, inner_outputs, output_grads)
callable_args = (local_inputs, local_outputs, output_grads) else:
elif self._lop_type == "grad": callable_args = (inner_inputs, output_grads)
callable_args = (local_inputs, output_grads)
# we need to convert _lop_op into an OfG instance # we need to convert _lop_op into an OfG instance
if lop_op is None: if lop_overrides is None:
gdefaults_l = fn_grad(wrt=local_inputs) input_grads = fn_grad(wrt=inner_inputs)
all_grads_l, all_grads_ov_l = zip( elif isinstance(lop_overrides, list):
*[ custom_input_grads = lop_overrides
OpFromGraph._filter_grad_var(grad, inp) if len(custom_input_grads) != nin:
for grad, inp in zip(gdefaults_l, local_inputs)
]
)
all_grads_l = list(all_grads_l)
all_grads_ov_l = list(all_grads_ov_l)
elif isinstance(lop_op, list):
goverrides_l = lop_op
if len(goverrides_l) != inp_len:
raise ValueError( raise ValueError(
f"Need to override {int(inp_len)} gradients, got {len(goverrides_l)}", f"Need to override {nin} gradients, got {len(custom_input_grads)}",
goverrides_l, custom_input_grads,
) )
# compute non-overriding downsteam grads from upstreams grads # compute non-overriding downsteam grads from upstreams grads
# it's normal some input may be disconnected, thus the 'ignore' # it's normal some input may be disconnected, thus the 'ignore'
wrt_l = [lin for lin, gov in zip(local_inputs, goverrides_l) if gov is None] wrt = [
gdefaults = iter(fn_grad(wrt=wrt_l) if wrt_l else []) lin for lin, gov in zip(inner_inputs, custom_input_grads) if gov is None
# combine overriding gradients ]
all_grads_l = [] default_input_grads = fn_grad(wrt=wrt) if wrt else []
all_grads_ov_l = [] input_grads = self._combine_list_overrides(
for inp, fn_gov in zip(local_inputs, goverrides_l): default_input_grads, custom_input_grads, callable_args
if fn_gov is None:
gnext, gnext_ov = OpFromGraph._filter_grad_var(next(gdefaults), inp)
all_grads_l.append(gnext)
all_grads_ov_l.append(gnext_ov)
elif isinstance(fn_gov, Variable):
if isinstance(fn_gov.type, DisconnectedType | NullType):
all_grads_l.append(inp.zeros_like())
all_grads_ov_l.append(fn_gov.type())
else:
raise ValueError(self.STYPE_ERR_MSG % fn_gov.type)
else:
if not callable(fn_gov):
raise TypeError(self.TYPE_ERR_MSG % fn_gov)
gov, gov_ov = OpFromGraph._filter_grad_var(
fn_gov(*callable_args), inp
)
all_grads_l.append(gov)
all_grads_ov_l.append(gov_ov)
else:
# callable case
if not callable(lop_op):
raise TypeError(self.TYPE_ERR_MSG % lop_op)
goverrides_l = lop_op(*callable_args)
if not isinstance(goverrides_l, list):
raise TypeError(
"Gradient/L_op overriding function should return a list, "
f'got "{type(goverrides_l)}"'
)
all_grads_l, all_grads_ov_l = zip(
*[
OpFromGraph._filter_grad_var(grad, inp)
for grad, inp in zip(goverrides_l, local_inputs)
]
) )
if len(all_grads_l) != len(local_inputs): else:
raise ValueError( input_grads = self._call_custom_override(lop_overrides, callable_args, nin)
"Gradient/L_op overriding function should return list of "
f"{int(inp_len)} outputs, got {len(all_grads_l)}" # Filter out disconnected input and output gradients
) connected_input_grads = [
all_grads_l = list(all_grads_l) inp_grad
all_grads_ov_l = list(all_grads_ov_l) for inp_grad in input_grads
self._lop_op = type(self)( if not isinstance(inp_grad.type, DisconnectedType | NullType)
inputs=local_inputs + local_outputs + output_grads, ]
outputs=all_grads_l, lop_op = type(self)(
inputs=inner_inputs + inner_outputs + output_grads,
outputs=connected_input_grads,
inline=self.is_inline, inline=self.is_inline,
name=(None if self.name is None else self.name + "_" + self._lop_type), name=(None if self.name is None else f"{self.name}_LOp"),
# TODO: We can be eager here and exclude unused inputs in the OFG
on_unused_input="ignore", on_unused_input="ignore",
) )
self._lop_op_stypes_l = all_grads_ov_l
self._lop_op_is_cached = True # Return a wrapper that combines connected and disconnected input gradients
self._lop_type = "lop" def wrapper(*inputs: Variable, **kwargs) -> list[Variable]:
connected_input_grads = iter(lop_op(*inputs, **kwargs))
return [
input_grad
if isinstance(input_grad.type, DisconnectedType | NullType)
else next(connected_input_grads)
for input_grad in input_grads
]
self._lop_op_cache = wrapper
return wrapper
@config.change_flags(compute_test_value="off") @config.change_flags(compute_test_value="off")
def _recompute_rop_op(self): def _build_and_cache_rop_op(self):
""" """Converts rop_overrides from user supplied form to type(self) instance.
converts self._rop_op from user supplied form to type(self) instance
Results are cached in self._rop_op_cache
""" """
local_inputs = self.inner_inputs if self._rop_op_cache is not None:
local_outputs = self.inner_outputs return self._rop_op_cache
out_len = len(local_outputs)
rop_op = self._rop_op inner_inputs = self.inner_inputs
inner_outputs = self.inner_outputs
if isinstance(rop_op, OpFromGraph): nout = len(inner_outputs)
if not self._rop_op_is_cached: rop_overrides = self.rop_overrides
self._rop_op_is_cached = True
self._rop_op_stypes_l = [None] * out_len if isinstance(rop_overrides, OpFromGraph):
return self._rop_op_cache = rop_overrides
return rop_overrides
eval_points = [inp_t() for inp_t in self.input_types] eval_points = [inp_t() for inp_t in self.input_types]
fn_rop = partial(Rop, wrt=local_inputs, eval_points=eval_points) fn_rop = partial(Rop, wrt=inner_inputs, eval_points=eval_points)
TYPE_ERR_MSG = (
"R_op overrides should be (single or list of)" callable_args = (inner_inputs, eval_points)
"OpFromGraph, None, a list or a callable, got %s" if rop_overrides is None:
) output_grads = fn_rop(f=inner_outputs)
STYPE_ERR_MSG = ( elif isinstance(rop_overrides, list):
"Overriding Variable instance can only have type" custom_output_grads = rop_overrides
" of DisconnectedType or NullType, got %s" if len(custom_output_grads) != nout:
)
if rop_op is None:
rdefaults_l = fn_rop(f=local_outputs)
all_rops_l, all_rops_ov_l = zip(
*[
OpFromGraph._filter_rop_var(rop, out)
for rop, out in zip(rdefaults_l, local_outputs)
]
)
all_rops_l = list(all_rops_l)
all_rops_ov_l = list(all_rops_ov_l)
elif isinstance(rop_op, list):
roverrides_l = rop_op
if len(roverrides_l) != out_len:
raise ValueError( raise ValueError(
f"Need to override {int(out_len)} Rop, got {len(roverrides_l)}", f"Need to override {int(nout)} Rop, got {len(custom_output_grads)}",
roverrides_l, custom_output_grads,
) )
# get outputs that does not have Rop override # get outputs that does not have Rop override
odefaults_l = [ f = [
lo for lo, rov in zip(local_outputs, roverrides_l) if rov is None output
for output, custom_output_grad in zip(
inner_outputs, custom_output_grads
)
if custom_output_grad is None
] ]
rdefaults_l = fn_rop(f=odefaults_l) default_output_grads = fn_rop(f=f) if f else []
rdefaults = iter(rdefaults_l if odefaults_l else []) output_grads = self._combine_list_overrides(
# combine overriding Rops default_output_grads, custom_output_grads, callable_args
all_rops_l = [] )
all_rops_ov_l = []
for out, fn_rov in zip(local_outputs, roverrides_l):
if fn_rov is None:
rnext, rnext_ov = OpFromGraph._filter_rop_var(next(rdefaults), out)
all_rops_l.append(rnext)
all_rops_ov_l.append(rnext_ov)
elif isinstance(fn_rov, Variable):
if isinstance(fn_rov.type, NullType):
all_rops_l.append(out.zeros_like())
all_rops_ov_l.append(fn_rov.type())
if isinstance(fn_rov.type, DisconnectedType):
all_rops_l.append(out.zeros_like())
all_rops_ov_l.append(None)
else:
raise ValueError(STYPE_ERR_MSG % fn_rov.type)
else:
if not callable(fn_rov):
raise TypeError(TYPE_ERR_MSG % fn_rov)
rov, rov_ov = OpFromGraph._filter_rop_var(
fn_rov(local_inputs, eval_points), out
)
all_rops_l.append(rov)
all_rops_ov_l.append(rov_ov)
else: else:
if not callable(rop_op): output_grads = self._call_custom_override(
raise TypeError(TYPE_ERR_MSG % rop_op) rop_overrides, callable_args, nout
roverrides_l = rop_op(local_inputs, eval_points)
if not isinstance(roverrides_l, list):
raise TypeError(
"Rop overriding function should return a list, "
f'got "{type(roverrides_l)}"'
)
all_rops_l, all_rops_ov_l = zip(
*[
OpFromGraph._filter_rop_var(rop, out)
for rop, out in zip(roverrides_l, local_outputs)
]
) )
if len(all_rops_l) != out_len:
raise ValueError( # Filter out disconnected output gradients
( filtered_output_grads = [
f"Rop overriding function {self._rop_op} should return list of " out_grad
f"{int(out_len)} outputs, got {len(all_rops_l)}", for out_grad in output_grads
), if not isinstance(out_grad.type, DisconnectedType | NullType)
rop_op, ]
) rop_op = type(self)(
all_rops_l = list(all_rops_l) inputs=inner_inputs + eval_points,
all_rops_ov_l = list(all_rops_ov_l) outputs=filtered_output_grads,
self._rop_op = type(self)(
inputs=local_inputs + eval_points,
outputs=all_rops_l,
inline=self.is_inline, inline=self.is_inline,
name=(None if self.name is None else self.name + "_rop"), name=(None if self.name is None else self.name + "_rop"),
on_unused_input="ignore", on_unused_input="ignore",
) )
self._rop_op_stypes_l = all_rops_ov_l
self._rop_op_is_cached = True
def get_lop_op(self):
if not self._lop_op_is_cached:
self._recompute_lop_op()
return self._lop_op
def get_rop_op(self):
if not self._rop_op_is_cached:
self._recompute_rop_op()
return self._rop_op
def set_grad_overrides(self, grad_overrides):
"""
Set gradient overrides.
This will completely remove any previously set L_op/gradient overrides
"""
self._lop_op = grad_overrides
self._lop_op_is_cached = False
self._lop_type = "grad"
def set_lop_overrides(self, lop_overrides):
"""
Set L_op overrides
This will completely remove any previously set L_op/gradient overrides
"""
self._lop_op = lop_overrides
self._lop_op_is_cached = False
self._lop_type = "lop"
def set_rop_overrides(self, rop_overrides): # Return a wrapper that combines connected and disconnected output gradients
""" def wrapper(*inputs: Variable, **kwargs) -> list[Variable | None]:
Set R_op overrides connected_output_grads = iter(rop_op(*inputs, **kwargs))
This will completely remove any previously set R_op overrides all_output_grads = []
for out_grad in output_grads:
if isinstance(out_grad.type, DisconnectedType):
# R_Op does not have DisconnectedType yet, None should be used instead
all_output_grads.append(None)
elif isinstance(out_grad.type, NullType):
all_output_grads.append(out_grad)
else:
all_output_grads.append(next(connected_output_grads))
return all_output_grads
""" self._rop_op_cache = wrapper
self._rop_op = rop_overrides return wrapper
self._rop_op_is_cached = False
def L_op(self, inputs, outputs, output_grads): def L_op(self, inputs, outputs, output_grads):
if not self._lop_op_is_cached: lop_op = self._build_and_cache_lop_op()
self._recompute_lop_op() return lop_op(*inputs, *outputs, *output_grads, return_list=True)
inps = list(inputs) + list(outputs) + list(output_grads)
ret_ofg_l = self._lop_op(*inps, return_list=True)
ret_l = [
ret_ofg if ov is None else ov
for ret_ofg, ov in zip(ret_ofg_l, self._lop_op_stypes_l)
]
return ret_l
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
if not self._rop_op_is_cached: rop_op = self._build_and_cache_rop_op()
self._recompute_rop_op() return rop_op(*inputs, *eval_points, return_list=True)
ret_ofg_l = self._rop_op(*(list(inputs) + list(eval_points)), return_list=True)
ret_l = [
ret_ofg if ov is None else ov
for ret_ofg, ov in zip(ret_ofg_l, self._rop_op_stypes_l)
]
return ret_l
def __call__(self, *inputs, **kwargs): def __call__(self, *inputs, **kwargs):
# The user interface doesn't expect the shared variable inputs of the # The user interface doesn't expect the shared variable inputs of the
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论