提交 35f0df96 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Thomas Wiecki

Make params exclusive to COp's

Also removes them from the signature of perform
上级 5f840278
.. _extending_op_params:
===============
Using Op params
===============
================
Using COp params
================
The Op params is a facility to pass some runtime parameters to the
The COp params is a facility to pass some runtime parameters to the
code of an op without modifying it. It can enable a single instance
of C code to serve different needs and therefore reduce compilation.
......@@ -53,7 +53,7 @@ following methods will be used for the type:
- :meth:`__hash__ <Type.__hash__>`
- :meth:`values_eq <Type.values_eq>`
Additionally if you want to use your params with C code, you need to extend `COp`
Additionally, to use your params with C code, you need to extend `COp`
and implement the following methods:
- :meth:`c_declare <CLinkerType.c_declare>`
......@@ -65,24 +65,24 @@ You can also define other convenience methods such as
:meth:`c_headers <CLinkerType.c_headers>` if you need any special things.
Registering the params with your Op
-----------------------------------
Registering the params with your COp
------------------------------------
To declare that your Op uses params you have to set the class
To declare that your `COp` uses params you have to set the class
attribute :attr:`params_type` to an instance of your params Type.
.. note::
If you want to have multiple parameters, PyTensor provides the convenient class
:class:`pytensor.link.c.params_type.ParamsType` that allows to bundle many parameters into
one object that will be available in both Python (as a Python object) and C code (as a struct).
one object that will be available to the C code (as a struct).
For example if we decide to use an int as the params the following
would be appropriate:
.. code-block:: python
class MyOp(Op):
class MyOp(COp):
params_type = Generic()
After that you need to define a :meth:`get_params` method on your
......@@ -115,12 +115,7 @@ Having declared a params for your Op will affect the expected
signature of :meth:`perform`. The new expected signature will have an
extra parameter at the end which corresponds to the params object.
.. warning::
If you do not account for this extra parameter, the code will fail
at runtime if it tries to run the python version.
Also, for the C code, the `sub` dictionary will contain an extra entry
The `sub` dictionary for `COp`s with params will contain an extra entry
`'params'` which will map to the variable name of the params object.
This is true for all methods that receive a `sub` parameter, so this
means that you can use your params in the :meth:`c_code <COp.c_code>`
......@@ -131,7 +126,7 @@ A simple example
----------------
This is a simple example which uses a params object to pass a value.
This `Op` will multiply a scalar input by a fixed floating point value.
This `COp` will multiply a scalar input by a fixed floating point value.
Since the value in this case is a python float, we chose Generic as
the params type.
......@@ -156,9 +151,10 @@ the params type.
inp = as_scalar(inp)
return Apply(self, [inp], [inp.type()])
def perform(self, node, inputs, output_storage, params):
# Here params is a python float so this is ok
output_storage[0][0] = inputs[0] * params
def perform(self, node, inputs, output_storage):
# Because params is a python float we can use `self.mul` directly.
# If it's something fancier, call `self.params_type.filter(self.get_params(node))`
output_storage[0][0] = inputs[0] * self.mul
def c_code(self, node, name, inputs, outputs, sub):
return ("%(z)s = %(x)s * PyFloat_AsDouble(%(p)s);" %
......@@ -174,7 +170,7 @@ weights.
.. testcode::
from pytensor.graph.op import Op
from pytensor.link.c.op import COp
from pytensor.link.c.type import Generic
from pytensor.scalar import as_scalar
......
......@@ -30,7 +30,6 @@ import numpy as np
from pytensor.configdefaults import config
from pytensor.graph.utils import (
MetaObject,
MethodNotDefined,
Scratchpad,
TestValueError,
ValidatingScratchpad,
......@@ -151,16 +150,6 @@ class Apply(Node, Generic[OpType]):
f"The 'outputs' argument to Apply must contain Variable instances with no owner, not {output}"
)
def run_params(self):
"""
Returns the params for the node, or NoParams if no params is set.
"""
try:
return self.op.get_params(self)
except MethodNotDefined:
return NoParams
def __getstate__(self):
d = self.__dict__
# ufunc don't pickle/unpickle well
......
......@@ -16,15 +16,13 @@ from typing import (
import pytensor
from pytensor.configdefaults import config
from pytensor.graph.basic import Apply, NoParams, Variable
from pytensor.graph.basic import Apply, Variable
from pytensor.graph.utils import (
MetaObject,
MethodNotDefined,
TestValueError,
add_tag_trace,
get_variable_trace_string,
)
from pytensor.link.c.params_type import Params, ParamsType
if TYPE_CHECKING:
......@@ -37,10 +35,7 @@ StorageMapType = dict[Variable, StorageCellType]
ComputeMapType = dict[Variable, list[bool]]
InputStorageType = list[StorageCellType]
OutputStorageType = list[StorageCellType]
ParamsInputType = Optional[tuple[Any, ...]]
PerformMethodType = Callable[
[Apply, list[Any], OutputStorageType, ParamsInputType], None
]
PerformMethodType = Callable[[Apply, list[Any], OutputStorageType], None]
BasicThunkType = Callable[[], None]
ThunkCallableType = Callable[
[PerformMethodType, StorageMapType, ComputeMapType, Apply], None
......@@ -202,7 +197,6 @@ class Op(MetaObject):
itypes: Optional[Sequence["Type"]] = None
otypes: Optional[Sequence["Type"]] = None
params_type: Optional[ParamsType] = None
_output_type_depends_on_input_value = False
"""
......@@ -426,7 +420,6 @@ class Op(MetaObject):
node: Apply,
inputs: Sequence[Any],
output_storage: OutputStorageType,
params: ParamsInputType = None,
) -> None:
"""Calculate the function on the inputs and put the variables in the output storage.
......@@ -442,8 +435,6 @@ class Op(MetaObject):
these lists). Each sub-list corresponds to value of each
`Variable` in :attr:`node.outputs`. The primary purpose of this method
is to set the values of these sub-lists.
params
A tuple containing the values of each entry in :attr:`Op.__props__`.
Notes
-----
......@@ -481,22 +472,6 @@ class Op(MetaObject):
"""
return True
def get_params(self, node: Apply) -> Params:
"""Try to get parameters for the `Op` when :attr:`Op.params_type` is set to a `ParamsType`."""
if isinstance(self.params_type, ParamsType):
wrapper = self.params_type
if not all(hasattr(self, field) for field in wrapper.fields):
# Let's print missing attributes for debugging.
not_found = tuple(
field for field in wrapper.fields if not hasattr(self, field)
)
raise AttributeError(
f"{type(self).__name__}: missing attributes {not_found} for ParamsType."
)
# ParamsType.get_params() will apply filtering to attributes.
return self.params_type.get_params(self)
raise MethodNotDefined("get_params")
def prepare_node(
self,
node: Apply,
......@@ -538,34 +513,12 @@ class Op(MetaObject):
else:
p = node.op.perform
params = node.run_params()
if params is NoParams:
# default arguments are stored in the closure of `rval`
@is_thunk_type
def rval(
p=p, i=node_input_storage, o=node_output_storage, n=node, params=None
):
r = p(n, [x[0] for x in i], o)
for o in node.outputs:
compute_map[o][0] = True
return r
else:
params_val = node.params_type.filter(params)
@is_thunk_type
def rval(
p=p,
i=node_input_storage,
o=node_output_storage,
n=node,
params=params_val,
):
r = p(n, [x[0] for x in i], o, params)
for o in node.outputs:
compute_map[o][0] = True
return r
@is_thunk_type
def rval(p=p, i=node_input_storage, o=node_output_storage, n=node):
r = p(n, [x[0] for x in i], o)
for o in node.outputs:
compute_map[o][0] = True
return r
rval.inputs = node_input_storage
rval.outputs = node_output_storage
......@@ -640,7 +593,7 @@ class _NoPythonOp(Op):
"""
def perform(self, node, inputs, output_storage, params=None):
def perform(self, node, inputs, output_storage):
raise NotImplementedError("No Python implementation is provided by this Op.")
......
......@@ -20,6 +20,7 @@ from pytensor.graph.basic import (
io_toposort,
vars_between,
)
from pytensor.graph.utils import MethodNotDefined
from pytensor.link.basic import Container, Linker, LocalLinker, PerformLinker
from pytensor.link.c.cmodule import (
METH_VARARGS,
......@@ -617,7 +618,12 @@ class CLinker(Linker):
# that needs it
self.node_params = dict()
for node in self.node_order:
params = node.run_params()
if not isinstance(node.op, CLinkerOp):
continue
try:
params = node.op.get_params(node)
except MethodNotDefined:
params = NoParams
if params is not NoParams:
# try to avoid creating more than one variable for the
# same params.
......@@ -803,7 +809,10 @@ class CLinker(Linker):
sub = dict(failure_var=failure_var)
params = node.run_params()
try:
params = op.get_params(node)
except MethodNotDefined:
params = NoParams
if params is not NoParams:
params_var = symbol[self.node_params[params]]
......
import typing
import warnings
from abc import abstractmethod
from typing import Callable
from typing import Callable, Optional
from pytensor.graph.basic import Apply, Constant
from pytensor.graph.utils import MethodNotDefined
if typing.TYPE_CHECKING:
from pytensor.link.c.params_type import Params, ParamsType
class CLinkerObject:
"""Standard methods for an `Op` or `Type` used with the `CLinker`."""
......@@ -172,6 +177,8 @@ class CLinkerObject:
class CLinkerOp(CLinkerObject):
"""Interface definition for `Op` subclasses compiled by `CLinker`."""
params_type: Optional["ParamsType"] = None
@abstractmethod
def c_code(
self,
......@@ -362,6 +369,22 @@ class CLinkerOp(CLinkerObject):
"""
return ""
def get_params(self, node: Apply) -> "Params":
"""Try to get parameters for the `Op` when :attr:`Op.params_type` is set to a `ParamsType`."""
if self.params_type is not None:
wrapper = self.params_type
if not all(hasattr(self, field) for field in wrapper.fields):
# Let's print missing attributes for debugging.
not_found = tuple(
field for field in wrapper.fields if not hasattr(self, field)
)
raise AttributeError(
f"{type(self).__name__}: missing attributes {not_found} for ParamsType."
)
# ParamsType.get_params() will apply filtering to attributes.
return self.params_type.get_params(self)
raise MethodNotDefined("get_params")
class CLinkerType(CLinkerObject):
r"""Interface specification for `Type`\s that can be arguments to a `CLinkerOp`.
......
......@@ -664,7 +664,7 @@ class _NoPythonCOp(COp):
"""
def perform(self, node, inputs, output_storage, params=None):
def perform(self, node, inputs, output_storage):
raise NotImplementedError("No Python implementation is provided by this COp.")
......@@ -675,7 +675,7 @@ class _NoPythonExternalCOp(ExternalCOp):
"""
def perform(self, node, inputs, output_storage, params=None):
def perform(self, node, inputs, output_storage):
raise NotImplementedError(
"No Python implementation is provided by this ExternalCOp."
)
......@@ -21,7 +21,7 @@ from numba.extending import box, overload
from pytensor import config
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.ops import DeepCopyOp
from pytensor.graph.basic import Apply, NoParams
from pytensor.graph.basic import Apply
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.type import Type
from pytensor.ifelse import IfElse
......@@ -383,22 +383,11 @@ def generate_fallback_impl(op, node=None, storage_map=None, **kwargs):
ret_sig = get_numba_type(node.outputs[0].type)
output_types = tuple(out.type for out in node.outputs)
params = node.run_params()
if params is not NoParams:
params_val = dict(node.params_type.filter(params))
def py_perform(inputs):
outputs = [[None] for i in range(n_outputs)]
op.perform(node, inputs, outputs, params_val)
return outputs
else:
def py_perform(inputs):
outputs = [[None] for i in range(n_outputs)]
op.perform(node, inputs, outputs)
return outputs
def py_perform(inputs):
outputs = [[None] for i in range(n_outputs)]
op.perform(node, inputs, outputs)
return outputs
if n_outputs == 1:
......
......@@ -90,7 +90,7 @@ class CheckAndRaise(COp):
[value.type()],
)
def perform(self, node, inputs, outputs, params):
def perform(self, node, inputs, outputs):
(out,) = outputs
val, *conds = inputs
out[0] = val
......
......@@ -1658,7 +1658,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
rval.lazy = False
return rval
def perform(self, node, inputs, output_storage, params=None):
def perform(self, node, inputs, output_storage):
"""Compute the scan operation in Python.
The `inputs` are packed like this:
......
......@@ -3991,11 +3991,11 @@ class AllocEmpty(COp):
output.tag.nan_guard_mode_check = False
return Apply(self, _shape, [output])
def debug_perform(self, node, inputs, out_, params):
self.perform(node, inputs, out_, params)
def debug_perform(self, node, inputs, out_):
self.perform(node, inputs, out_)
out_[0][0].fill(-123456789)
def perform(self, node, inputs, out_, params):
def perform(self, node, inputs, out_):
(out,) = out_
sh = tuple([int(i) for i in inputs])
if out[0] is None or out[0].shape != sh:
......
......@@ -207,7 +207,7 @@ class Gemv(Op):
return Apply(self, inputs, [y.type()])
def perform(self, node, inputs, out_storage, params=None):
def perform(self, node, inputs, out_storage):
y, alpha, A, x, beta = inputs
if (
have_fblas
......@@ -309,7 +309,7 @@ class Ger(Op):
return Apply(self, inputs, [A.type()])
def perform(self, node, inp, out, params=None):
def perform(self, node, inp, out):
cA, calpha, cx, cy = inp
(cZ,) = out
if self.destructive:
......@@ -912,12 +912,12 @@ class Gemm(GemmRelated):
output = z.type()
return Apply(self, inputs, [output])
def perform(self, node, inp, out, params):
def perform(self, node, inp, out):
z, a, x, y, b = inp
(zout,) = out
assert a.shape == ()
assert b.shape == ()
if not params.inplace:
if not self.inplace:
z = z.copy() # the original z will not be changed
if z.shape == ():
z.itemset(z * a + b * np.dot(x, y))
......
......@@ -233,7 +233,7 @@ class DimShuffle(ExternalCOp):
return f"Transpose{{axes={self.shuffle}}}"
return f"DimShuffle{{order=[{','.join(map(str, self.new_order))}]}}"
def perform(self, node, inp, out, params=None):
def perform(self, node, inp, out):
(res,) = inp
(storage,) = out
......
......@@ -145,7 +145,7 @@ class SearchsortedOp(COp):
def infer_shape(self, fgraph, node, shapes):
return [shapes[1]]
def perform(self, node, inputs, output_storage, params):
def perform(self, node, inputs, output_storage):
x = inputs[0]
v = inputs[1]
if len(node.inputs) == 3:
......@@ -154,7 +154,7 @@ class SearchsortedOp(COp):
sorter = None
z = output_storage[0]
z[0] = np.searchsorted(x, v, side=params, sorter=sorter).astype(
z[0] = np.searchsorted(x, v, side=self.side, sorter=sorter).astype(
node.outputs[0].dtype
)
......@@ -310,7 +310,7 @@ class CumOp(COp):
return Apply(self, [x], [out_type])
def perform(self, node, inputs, output_storage, params):
def perform(self, node, inputs, output_storage):
x = inputs[0]
z = output_storage[0]
if self.mode == "add":
......
......@@ -152,9 +152,9 @@ class MaxAndArgmax(COp):
]
return Apply(self, inputs, outputs)
def perform(self, node, inp, outs, params):
def perform(self, node, inp, outs):
x = inp[0]
axes = params
axes = self.axis
max, max_idx = outs
if axes is None:
axes = tuple(range(x.ndim))
......@@ -374,7 +374,7 @@ class Argmax(COp):
"You are trying to compile a graph with an old Argmax node. Either reoptimize your graph or rebuild it to get the new node format."
)
def perform(self, node, inp, outs, params):
def perform(self, node, inp, outs):
(x,) = inp
axes = self.axis
(max_idx,) = outs
......
......@@ -48,7 +48,7 @@ def local_max_and_argmax(fgraph, node):
If we don't use the argmax, change it to a max only.
"""
if isinstance(node.op, MaxAndArgmax):
axis = node.op.get_params(node)
axis = node.op.axis
if len(fgraph.clients[node.outputs[1]]) == 0:
new = Max(axis)(node.inputs[0])
copy_stack_trace(node.outputs[0], new)
......
......@@ -237,7 +237,7 @@ class Shape_i(COp):
raise TypeError(f"{x} has too few dimensions for Shape_i")
return Apply(self, [x], [pytensor.tensor.type.lscalar()])
def perform(self, node, inp, out_, params):
def perform(self, node, inp, out_):
(x,) = inp
(out,) = out_
if out[0] is None:
......@@ -668,7 +668,7 @@ class Reshape(COp):
return Apply(self, [x, shp], [tensor(dtype=x.type.dtype, shape=out_shape)])
def perform(self, node, inp, out_, params=None):
def perform(self, node, inp, out_):
x, shp = inp
(out,) = out_
if len(shp) != self.ndim:
......
......@@ -2474,7 +2474,7 @@ class AdvancedIncSubtensor1(COp):
def c_code_cache_version(self):
return (8,)
def perform(self, node, inp, out_, params):
def perform(self, node, inp, out_):
x, y, idx = inp
(out,) = out_
if not self.inplace:
......
......@@ -31,7 +31,8 @@ class QuadraticOpFunc(COp):
x = at.as_tensor_variable(x)
return Apply(self, [x], [x.type()])
def perform(self, node, inputs, output_storage, coefficients):
def perform(self, node, inputs, output_storage):
coefficients = self.params_type.filter(self.get_params(node))
x = inputs[0]
y = output_storage[0]
y[0] = coefficients.a * (x**2) + coefficients.b * x + coefficients.c
......@@ -117,7 +118,8 @@ class QuadraticCOpFunc(ExternalCOp):
x = at.as_tensor_variable(x)
return Apply(self, [x], [x.type()])
def perform(self, node, inputs, output_storage, coefficients):
def perform(self, node, inputs, output_storage):
coefficients = self.params_type.filter(self.get_params(node))
x = inputs[0]
y = output_storage[0]
y[0] = coefficients.a * (x**2) + coefficients.b * x + coefficients.c
......
......@@ -117,7 +117,8 @@ class MyOpEnumList(COp):
def make_node(self, a, b):
return Apply(self, [aes.as_scalar(a), aes.as_scalar(b)], [aes.float64()])
def perform(self, node, inputs, outputs, op):
def perform(self, node, inputs, outputs):
op = self.params_type.filter(self.get_params(node))
a, b = inputs
(o,) = outputs
if op == self.params_type.ADD:
......
......@@ -12,6 +12,7 @@ from pytensor.graph.replace import vectorize_node
from pytensor.tensor import diagonal, log, tensor
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.nlinalg import MatrixInverse
from pytensor.tensor.shape import Shape
from pytensor.tensor.slinalg import Cholesky, Solve, cholesky, solve_triangular
from pytensor.tensor.utils import _parse_gufunc_signature
......@@ -359,3 +360,13 @@ def test_batched_mvnormal_logp_and_dlogp(mu_batch_shape, cov_batch_shape, benchm
fn = pytensor.function([value, mu, cov], [logp, *dlogp])
benchmark(fn, *test_values)
def test_op_with_params():
matrix_shape_blockwise = Blockwise(core_op=Shape(), signature="(x1,x2)->(s)")
x = tensor("x", shape=(5, None, None), dtype="float64")
x_shape = matrix_shape_blockwise(x)
fn = pytensor.function([x], x_shape)
pytensor.dprint(fn)
# Assert blockwise
print(fn(np.zeros((5, 3, 2))))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论