Commit 35f0df96 authored by Ricardo Vieira, committed by Thomas Wiecki

Make params exclusive to COp's

Also removes them from the signature of perform

Parent 5f840278
 .. _extending_op_params:

-===============
-Using Op params
-===============
+================
+Using COp params
+================

-The Op params is a facility to pass some runtime parameters to the
+The COp params is a facility to pass some runtime parameters to the
 code of an op without modifying it. It can enable a single instance
 of C code to serve different needs and therefore reduce compilation.
@@ -53,7 +53,7 @@ following methods will be used for the type:
 - :meth:`__hash__ <Type.__hash__>`
 - :meth:`values_eq <Type.values_eq>`

-Additionally if you want to use your params with C code, you need to extend `COp`
+Additionally, to use your params with C code, you need to extend `COp`
 and implement the following methods:

 - :meth:`c_declare <CLinkerType.c_declare>`
@@ -65,24 +65,24 @@ You can also define other convenience methods such as
 :meth:`c_headers <CLinkerType.c_headers>` if you need any special things.
-Registering the params with your Op
------------------------------------
+Registering the params with your COp
+------------------------------------

-To declare that your Op uses params you have to set the class
+To declare that your `COp` uses params you have to set the class
 attribute :attr:`params_type` to an instance of your params Type.

 .. note::
     If you want to have multiple parameters, PyTensor provides the convenient class
     :class:`pytensor.link.c.params_type.ParamsType` that allows to bundle many parameters into
-    one object that will be available in both Python (as a Python object) and C code (as a struct).
+    one object that will be available to the C code (as a struct).

 For example if we decide to use an int as the params the following
 would be appropriate:

 .. code-block:: python

-    class MyOp(Op):
+    class MyOp(COp):
         params_type = Generic()

 After that you need to define a :meth:`get_params` method on your
@@ -115,12 +115,7 @@ Having declared a params for your Op will affect the expected
 signature of :meth:`perform`. The new expected signature will have an
 extra parameter at the end which corresponds to the params object.

-.. warning::
-    If you do not account for this extra parameter, the code will fail
-    at runtime if it tries to run the python version.
-
-Also, for the C code, the `sub` dictionary will contain an extra entry
+The `sub` dictionary for `COp`\s with params will contain an extra entry
 `'params'` which will map to the variable name of the params object.
 This is true for all methods that receive a `sub` parameter, so this
 means that you can use your params in the :meth:`c_code <COp.c_code>`
@@ -131,7 +126,7 @@ A simple example
 ----------------

 This is a simple example which uses a params object to pass a value.
-This `Op` will multiply a scalar input by a fixed floating point value.
+This `COp` will multiply a scalar input by a fixed floating point value.
 Since the value in this case is a python float, we chose Generic as
 the params type.

@@ -156,9 +151,10 @@ the params type.
         inp = as_scalar(inp)
         return Apply(self, [inp], [inp.type()])

-    def perform(self, node, inputs, output_storage, params):
-        # Here params is a python float so this is ok
-        output_storage[0][0] = inputs[0] * params
+    def perform(self, node, inputs, output_storage):
+        # Because params is a python float we can use `self.mul` directly.
+        # If it's something fancier, call `self.params_type.filter(self.get_params(node))`
+        output_storage[0][0] = inputs[0] * self.mul
     def c_code(self, node, name, inputs, outputs, sub):
         return ("%(z)s = %(x)s * PyFloat_AsDouble(%(p)s);" %
@@ -174,7 +170,7 @@ weights.

 .. testcode::

-    from pytensor.graph.op import Op
+    from pytensor.link.c.op import COp
     from pytensor.link.c.type import Generic
     from pytensor.scalar import as_scalar
...
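The documentation change above boils the new contract down to: `perform` no longer receives a trailing `params` argument, and a Python-side implementation reads the values straight off `self`. A minimal standalone sketch of that pattern (plain Python, no PyTensor dependency; `MulByConstant` and its attribute are hypothetical stand-ins for the doc's `MyOp`):

```python
# Sketch of the post-commit pattern: parameters live on the instance and
# `perform` takes only (node, inputs, output_storage).
class MulByConstant:
    __props__ = ("mul",)

    def __init__(self, mul):
        self.mul = mul

    def perform(self, node, inputs, output_storage):
        # No `params` argument anymore; read the value from `self`.
        output_storage[0][0] = inputs[0] * self.mul


op = MulByConstant(3.5)
storage = [[None]]  # one output cell, PyTensor-style
op.perform(None, [2.0], storage)
print(storage[0][0])  # 7.0
```

The same instance attributes feed the C side via `params_type`/`get_params`, so a single source of truth serves both implementations.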
@@ -30,7 +30,6 @@ import numpy as np
 from pytensor.configdefaults import config
 from pytensor.graph.utils import (
     MetaObject,
-    MethodNotDefined,
     Scratchpad,
     TestValueError,
     ValidatingScratchpad,
@@ -151,16 +150,6 @@ class Apply(Node, Generic[OpType]):
                 f"The 'outputs' argument to Apply must contain Variable instances with no owner, not {output}"
             )

-    def run_params(self):
-        """
-        Returns the params for the node, or NoParams if no params is set.
-        """
-        try:
-            return self.op.get_params(self)
-        except MethodNotDefined:
-            return NoParams
-
     def __getstate__(self):
         d = self.__dict__
         # ufunc don't pickle/unpickle well
...
@@ -16,15 +16,13 @@ from typing import (

 import pytensor
 from pytensor.configdefaults import config
-from pytensor.graph.basic import Apply, NoParams, Variable
+from pytensor.graph.basic import Apply, Variable
 from pytensor.graph.utils import (
     MetaObject,
-    MethodNotDefined,
     TestValueError,
     add_tag_trace,
     get_variable_trace_string,
 )
-from pytensor.link.c.params_type import Params, ParamsType

 if TYPE_CHECKING:
@@ -37,10 +35,7 @@ StorageMapType = dict[Variable, StorageCellType]
 ComputeMapType = dict[Variable, list[bool]]
 InputStorageType = list[StorageCellType]
 OutputStorageType = list[StorageCellType]
-ParamsInputType = Optional[tuple[Any, ...]]
-PerformMethodType = Callable[
-    [Apply, list[Any], OutputStorageType, ParamsInputType], None
-]
+PerformMethodType = Callable[[Apply, list[Any], OutputStorageType], None]
 BasicThunkType = Callable[[], None]
 ThunkCallableType = Callable[
     [PerformMethodType, StorageMapType, ComputeMapType, Apply], None
@@ -202,7 +197,6 @@ class Op(MetaObject):
     itypes: Optional[Sequence["Type"]] = None
     otypes: Optional[Sequence["Type"]] = None
-    params_type: Optional[ParamsType] = None
     _output_type_depends_on_input_value = False

     """
@@ -426,7 +420,6 @@ class Op(MetaObject):
         node: Apply,
         inputs: Sequence[Any],
         output_storage: OutputStorageType,
-        params: ParamsInputType = None,
     ) -> None:
         """Calculate the function on the inputs and put the variables in the output storage.
@@ -442,8 +435,6 @@ class Op(MetaObject):
             these lists). Each sub-list corresponds to value of each
             `Variable` in :attr:`node.outputs`. The primary purpose of this method
             is to set the values of these sub-lists.
-        params
-            A tuple containing the values of each entry in :attr:`Op.__props__`.

         Notes
         -----
@@ -481,22 +472,6 @@ class Op(MetaObject):
         """
         return True

-    def get_params(self, node: Apply) -> Params:
-        """Try to get parameters for the `Op` when :attr:`Op.params_type` is set to a `ParamsType`."""
-        if isinstance(self.params_type, ParamsType):
-            wrapper = self.params_type
-            if not all(hasattr(self, field) for field in wrapper.fields):
-                # Let's print missing attributes for debugging.
-                not_found = tuple(
-                    field for field in wrapper.fields if not hasattr(self, field)
-                )
-                raise AttributeError(
-                    f"{type(self).__name__}: missing attributes {not_found} for ParamsType."
-                )
-            # ParamsType.get_params() will apply filtering to attributes.
-            return self.params_type.get_params(self)
-        raise MethodNotDefined("get_params")
-
     def prepare_node(
         self,
         node: Apply,
@@ -538,34 +513,12 @@
         else:
             p = node.op.perform

-        params = node.run_params()
-
-        if params is NoParams:
-            # default arguments are stored in the closure of `rval`
-            @is_thunk_type
-            def rval(
-                p=p, i=node_input_storage, o=node_output_storage, n=node, params=None
-            ):
-                r = p(n, [x[0] for x in i], o)
-                for o in node.outputs:
-                    compute_map[o][0] = True
-                return r
-
-        else:
-            params_val = node.params_type.filter(params)
-
-            @is_thunk_type
-            def rval(
-                p=p,
-                i=node_input_storage,
-                o=node_output_storage,
-                n=node,
-                params=params_val,
-            ):
-                r = p(n, [x[0] for x in i], o, params)
-                for o in node.outputs:
-                    compute_map[o][0] = True
-                return r
+        @is_thunk_type
+        def rval(p=p, i=node_input_storage, o=node_output_storage, n=node):
+            r = p(n, [x[0] for x in i], o)
+            for o in node.outputs:
+                compute_map[o][0] = True
+            return r

         rval.inputs = node_input_storage
         rval.outputs = node_output_storage
@@ -640,7 +593,7 @@ class _NoPythonOp(Op):
     """

-    def perform(self, node, inputs, output_storage, params=None):
+    def perform(self, node, inputs, output_storage):
         raise NotImplementedError("No Python implementation is provided by this Op.")
...
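With the `NoParams` branch gone, the linker builds a single thunk closure regardless of whether the op carries params. A standalone sketch of that closure (plain Python; `make_thunk` and all names here are stand-ins for the real linker machinery, not PyTensor API):

```python
# Sketch of the simplified thunk: default arguments freeze the storage lists
# into the closure, `perform` is called without a params argument, and the
# compute map is flagged for every output afterwards.
def make_thunk(perform, node, input_storage, output_storage, compute_map, outputs):
    def rval(p=perform, i=input_storage, o=output_storage, n=node):
        r = p(n, [cell[0] for cell in i], o)
        for out in outputs:
            compute_map[out][0] = True
        return r

    return rval


# Toy usage: an "op" whose perform adds its two inputs.
def add_perform(node, inputs, output_storage):
    output_storage[0][0] = inputs[0] + inputs[1]


in_storage = [[1], [2]]
out = [[None]]
cmap = {"y": [False]}
thunk = make_thunk(add_perform, None, in_storage, out, cmap, ["y"])
thunk()
print(out[0][0], cmap["y"][0])  # 3 True
```

Because the params value used to be filtered once and captured as a default argument, removing it also removes one place where stale parameter state could be baked into a compiled thunk.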
@@ -20,6 +20,7 @@ from pytensor.graph.basic import (
     io_toposort,
     vars_between,
 )
+from pytensor.graph.utils import MethodNotDefined
 from pytensor.link.basic import Container, Linker, LocalLinker, PerformLinker
 from pytensor.link.c.cmodule import (
     METH_VARARGS,
@@ -617,7 +618,12 @@ class CLinker(Linker):
         # that needs it
         self.node_params = dict()
         for node in self.node_order:
-            params = node.run_params()
+            if not isinstance(node.op, CLinkerOp):
+                continue
+            try:
+                params = node.op.get_params(node)
+            except MethodNotDefined:
+                params = NoParams
             if params is not NoParams:
                 # try to avoid creating more than one variable for the
                 # same params.
@@ -803,7 +809,10 @@ class CLinker(Linker):
             sub = dict(failure_var=failure_var)

-            params = node.run_params()
+            try:
+                params = op.get_params(node)
+            except MethodNotDefined:
+                params = NoParams
             if params is not NoParams:
                 params_var = symbol[self.node_params[params]]
...
+import typing
 import warnings
 from abc import abstractmethod
-from typing import Callable
+from typing import Callable, Optional

 from pytensor.graph.basic import Apply, Constant
 from pytensor.graph.utils import MethodNotDefined

+if typing.TYPE_CHECKING:
+    from pytensor.link.c.params_type import Params, ParamsType

 class CLinkerObject:
     """Standard methods for an `Op` or `Type` used with the `CLinker`."""
@@ -172,6 +177,8 @@ class CLinkerObject:
 class CLinkerOp(CLinkerObject):
     """Interface definition for `Op` subclasses compiled by `CLinker`."""

+    params_type: Optional["ParamsType"] = None
+
     @abstractmethod
     def c_code(
         self,
@@ -362,6 +369,22 @@ class CLinkerOp(CLinkerObject):
         """
         return ""

+    def get_params(self, node: Apply) -> "Params":
+        """Try to get parameters for the `Op` when :attr:`Op.params_type` is set to a `ParamsType`."""
+        if self.params_type is not None:
+            wrapper = self.params_type
+            if not all(hasattr(self, field) for field in wrapper.fields):
+                # Let's print missing attributes for debugging.
+                not_found = tuple(
+                    field for field in wrapper.fields if not hasattr(self, field)
+                )
+                raise AttributeError(
+                    f"{type(self).__name__}: missing attributes {not_found} for ParamsType."
+                )
+            # ParamsType.get_params() will apply filtering to attributes.
+            return self.params_type.get_params(self)
+        raise MethodNotDefined("get_params")
+
 class CLinkerType(CLinkerObject):
     r"""Interface specification for `Type`\s that can be arguments to a `CLinkerOp`.
...
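The `get_params` method moved above collects the instance attributes named by `params_type.fields` and fails loudly when one is missing. A standalone sketch of that lookup-and-validate logic (plain Python; `FakeParamsType` and `MyOp` are hypothetical stand-ins, not PyTensor classes):

```python
# Stand-in for ParamsType: knows its field names and pulls them off an op.
class FakeParamsType:
    def __init__(self, *fields):
        self.fields = fields

    def get_params(self, op):
        # The real ParamsType.get_params() also applies type filtering here.
        return tuple(getattr(op, f) for f in self.fields)


class MyOp:
    params_type = FakeParamsType("alpha", "beta")

    def __init__(self, alpha, beta):
        self.alpha, self.beta = alpha, beta

    def get_params(self, node):
        wrapper = self.params_type
        missing = tuple(f for f in wrapper.fields if not hasattr(self, f))
        if missing:
            # Name the missing attributes for easier debugging.
            raise AttributeError(
                f"{type(self).__name__}: missing attributes {missing} for ParamsType."
            )
        return wrapper.get_params(self)


print(MyOp(1.0, 2.0).get_params(None))  # (1.0, 2.0)
```

Listing the missing fields in the exception, rather than letting a bare `AttributeError` escape from `getattr`, is what makes the "missing attributes for ParamsType" failure mode diagnosable.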
@@ -664,7 +664,7 @@ class _NoPythonCOp(COp):
     """

-    def perform(self, node, inputs, output_storage, params=None):
+    def perform(self, node, inputs, output_storage):
         raise NotImplementedError("No Python implementation is provided by this COp.")
@@ -675,7 +675,7 @@ class _NoPythonExternalCOp(ExternalCOp):
     """

-    def perform(self, node, inputs, output_storage, params=None):
+    def perform(self, node, inputs, output_storage):
         raise NotImplementedError(
             "No Python implementation is provided by this ExternalCOp."
         )
@@ -21,7 +21,7 @@ from numba.extending import box, overload

 from pytensor import config
 from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.ops import DeepCopyOp
-from pytensor.graph.basic import Apply, NoParams
+from pytensor.graph.basic import Apply
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.type import Type
 from pytensor.ifelse import IfElse
@@ -383,22 +383,11 @@ def generate_fallback_impl(op, node=None, storage_map=None, **kwargs):
     ret_sig = get_numba_type(node.outputs[0].type)
     output_types = tuple(out.type for out in node.outputs)
-    params = node.run_params()

-    if params is not NoParams:
-        params_val = dict(node.params_type.filter(params))
-
-        def py_perform(inputs):
-            outputs = [[None] for i in range(n_outputs)]
-            op.perform(node, inputs, outputs, params_val)
-            return outputs
-
-    else:
-
-        def py_perform(inputs):
-            outputs = [[None] for i in range(n_outputs)]
-            op.perform(node, inputs, outputs)
-            return outputs
+    def py_perform(inputs):
+        outputs = [[None] for i in range(n_outputs)]
+        op.perform(node, inputs, outputs)
+        return outputs

     if n_outputs == 1:
...
@@ -90,7 +90,7 @@ class CheckAndRaise(COp):
             [value.type()],
         )

-    def perform(self, node, inputs, outputs, params):
+    def perform(self, node, inputs, outputs):
         (out,) = outputs
         val, *conds = inputs
         out[0] = val
...
@@ -1658,7 +1658,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
         rval.lazy = False
         return rval

-    def perform(self, node, inputs, output_storage, params=None):
+    def perform(self, node, inputs, output_storage):
         """Compute the scan operation in Python.

         The `inputs` are packed like this:
...
@@ -3991,11 +3991,11 @@ class AllocEmpty(COp):
         output.tag.nan_guard_mode_check = False
         return Apply(self, _shape, [output])

-    def debug_perform(self, node, inputs, out_, params):
-        self.perform(node, inputs, out_, params)
+    def debug_perform(self, node, inputs, out_):
+        self.perform(node, inputs, out_)
         out_[0][0].fill(-123456789)

-    def perform(self, node, inputs, out_, params):
+    def perform(self, node, inputs, out_):
         (out,) = out_
         sh = tuple([int(i) for i in inputs])
         if out[0] is None or out[0].shape != sh:
...
@@ -207,7 +207,7 @@ class Gemv(Op):
         return Apply(self, inputs, [y.type()])

-    def perform(self, node, inputs, out_storage, params=None):
+    def perform(self, node, inputs, out_storage):
         y, alpha, A, x, beta = inputs
         if (
             have_fblas
@@ -309,7 +309,7 @@ class Ger(Op):
         return Apply(self, inputs, [A.type()])

-    def perform(self, node, inp, out, params=None):
+    def perform(self, node, inp, out):
         cA, calpha, cx, cy = inp
         (cZ,) = out
         if self.destructive:
@@ -912,12 +912,12 @@ class Gemm(GemmRelated):
         output = z.type()
         return Apply(self, inputs, [output])

-    def perform(self, node, inp, out, params):
+    def perform(self, node, inp, out):
         z, a, x, y, b = inp
         (zout,) = out
         assert a.shape == ()
         assert b.shape == ()
-        if not params.inplace:
+        if not self.inplace:
             z = z.copy()  # the original z will not be changed
         if z.shape == ():
             z.itemset(z * a + b * np.dot(x, y))
...
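The `Gemm` hunk shows the migration pattern repeated across these files: flags that `perform` used to read off the `params` object (here `params.inplace`) now come from `self`. A minimal before/after sketch with scalars instead of arrays (plain Python stand-in, not the real BLAS Op; `GemmLike` is a hypothetical name):

```python
# Stand-in for the migrated Gemm.perform: the `inplace` flag is an
# instance attribute, and the signature has no trailing `params`.
class GemmLike:
    __props__ = ("inplace",)

    def __init__(self, inplace):
        self.inplace = inplace

    def perform(self, node, inp, out):
        z, a, x, y, b = inp          # scalars here; arrays in the real Op
        (zout,) = out
        # Scalar analogue of z * a + b * np.dot(x, y)
        zout[0] = z * a + b * x * y


op = GemmLike(inplace=False)
out = [[None]]
op.perform(None, (1.0, 2.0, 3.0, 4.0, 0.5), out)
print(out[0][0])  # 1*2 + 0.5*3*4 = 8.0
```

Since `params` was built from `__props__` in the first place, reading `self.inplace` directly is behaviorally equivalent and removes one indirection.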
@@ -233,7 +233,7 @@ class DimShuffle(ExternalCOp):
             return f"Transpose{{axes={self.shuffle}}}"
         return f"DimShuffle{{order=[{','.join(map(str, self.new_order))}]}}"

-    def perform(self, node, inp, out, params=None):
+    def perform(self, node, inp, out):
         (res,) = inp
         (storage,) = out
...
@@ -145,7 +145,7 @@ class SearchsortedOp(COp):
     def infer_shape(self, fgraph, node, shapes):
         return [shapes[1]]

-    def perform(self, node, inputs, output_storage, params):
+    def perform(self, node, inputs, output_storage):
         x = inputs[0]
         v = inputs[1]
         if len(node.inputs) == 3:
@@ -154,7 +154,7 @@ class SearchsortedOp(COp):
             sorter = None
         z = output_storage[0]

-        z[0] = np.searchsorted(x, v, side=params, sorter=sorter).astype(
+        z[0] = np.searchsorted(x, v, side=self.side, sorter=sorter).astype(
             node.outputs[0].dtype
         )
@@ -310,7 +310,7 @@ class CumOp(COp):
         return Apply(self, [x], [out_type])

-    def perform(self, node, inputs, output_storage, params):
+    def perform(self, node, inputs, output_storage):
         x = inputs[0]
         z = output_storage[0]
         if self.mode == "add":
...
@@ -152,9 +152,9 @@ class MaxAndArgmax(COp):
         ]
         return Apply(self, inputs, outputs)

-    def perform(self, node, inp, outs, params):
+    def perform(self, node, inp, outs):
         x = inp[0]
-        axes = params
+        axes = self.axis
         max, max_idx = outs
         if axes is None:
             axes = tuple(range(x.ndim))
@@ -374,7 +374,7 @@ class Argmax(COp):
             "You are trying to compile a graph with an old Argmax node. Either reoptimize your graph or rebuild it to get the new node format."
         )

-    def perform(self, node, inp, outs, params):
+    def perform(self, node, inp, outs):
         (x,) = inp
         axes = self.axis
         (max_idx,) = outs
...
@@ -48,7 +48,7 @@ def local_max_and_argmax(fgraph, node):
     If we don't use the argmax, change it to a max only.
     """
     if isinstance(node.op, MaxAndArgmax):
-        axis = node.op.get_params(node)
+        axis = node.op.axis
         if len(fgraph.clients[node.outputs[1]]) == 0:
             new = Max(axis)(node.inputs[0])
             copy_stack_trace(node.outputs[0], new)
...
@@ -237,7 +237,7 @@ class Shape_i(COp):
             raise TypeError(f"{x} has too few dimensions for Shape_i")
         return Apply(self, [x], [pytensor.tensor.type.lscalar()])

-    def perform(self, node, inp, out_, params):
+    def perform(self, node, inp, out_):
         (x,) = inp
         (out,) = out_
         if out[0] is None:
@@ -668,7 +668,7 @@ class Reshape(COp):
         return Apply(self, [x, shp], [tensor(dtype=x.type.dtype, shape=out_shape)])

-    def perform(self, node, inp, out_, params=None):
+    def perform(self, node, inp, out_):
         x, shp = inp
         (out,) = out_
         if len(shp) != self.ndim:
...
@@ -2474,7 +2474,7 @@ class AdvancedIncSubtensor1(COp):
     def c_code_cache_version(self):
         return (8,)

-    def perform(self, node, inp, out_, params):
+    def perform(self, node, inp, out_):
         x, y, idx = inp
         (out,) = out_
         if not self.inplace:
...
@@ -31,7 +31,8 @@ class QuadraticOpFunc(COp):
         x = at.as_tensor_variable(x)
         return Apply(self, [x], [x.type()])

-    def perform(self, node, inputs, output_storage, coefficients):
+    def perform(self, node, inputs, output_storage):
+        coefficients = self.params_type.filter(self.get_params(node))
         x = inputs[0]
         y = output_storage[0]
         y[0] = coefficients.a * (x**2) + coefficients.b * x + coefficients.c
@@ -117,7 +118,8 @@ class QuadraticCOpFunc(ExternalCOp):
         x = at.as_tensor_variable(x)
         return Apply(self, [x], [x.type()])

-    def perform(self, node, inputs, output_storage, coefficients):
+    def perform(self, node, inputs, output_storage):
+        coefficients = self.params_type.filter(self.get_params(node))
         x = inputs[0]
         y = output_storage[0]
         y[0] = coefficients.a * (x**2) + coefficients.b * x + coefficients.c
...
@@ -117,7 +117,8 @@ class MyOpEnumList(COp):
     def make_node(self, a, b):
         return Apply(self, [aes.as_scalar(a), aes.as_scalar(b)], [aes.float64()])

-    def perform(self, node, inputs, outputs, op):
+    def perform(self, node, inputs, outputs):
+        op = self.params_type.filter(self.get_params(node))
         a, b = inputs
         (o,) = outputs
         if op == self.params_type.ADD:
...
@@ -12,6 +12,7 @@ from pytensor.graph.replace import vectorize_node
 from pytensor.tensor import diagonal, log, tensor
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.nlinalg import MatrixInverse
+from pytensor.tensor.shape import Shape
 from pytensor.tensor.slinalg import Cholesky, Solve, cholesky, solve_triangular
 from pytensor.tensor.utils import _parse_gufunc_signature
@@ -359,3 +360,13 @@ def test_batched_mvnormal_logp_and_dlogp(mu_batch_shape, cov_batch_shape, benchm
     fn = pytensor.function([value, mu, cov], [logp, *dlogp])
     benchmark(fn, *test_values)

+def test_op_with_params():
+    matrix_shape_blockwise = Blockwise(core_op=Shape(), signature="(x1,x2)->(s)")
+    x = tensor("x", shape=(5, None, None), dtype="float64")
+    x_shape = matrix_shape_blockwise(x)
+    fn = pytensor.function([x], x_shape)
+    pytensor.dprint(fn)
+    # Assert blockwise
+    print(fn(np.zeros((5, 3, 2))))