提交 e39fda37 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Make blockwise perform method node dependent

上级 a377c22d
from collections.abc import Sequence from collections.abc import Sequence
from copy import copy
from typing import Any, cast from typing import Any, cast
import numpy as np import numpy as np
...@@ -79,7 +78,6 @@ class Blockwise(Op): ...@@ -79,7 +78,6 @@ class Blockwise(Op):
self.name = name self.name = name
self.inputs_sig, self.outputs_sig = _parse_gufunc_signature(signature) self.inputs_sig, self.outputs_sig = _parse_gufunc_signature(signature)
self.gufunc_spec = gufunc_spec self.gufunc_spec = gufunc_spec
self._gufunc = None
if destroy_map is not None: if destroy_map is not None:
self.destroy_map = destroy_map self.destroy_map = destroy_map
if self.destroy_map != core_op.destroy_map: if self.destroy_map != core_op.destroy_map:
...@@ -91,11 +89,6 @@ class Blockwise(Op): ...@@ -91,11 +89,6 @@ class Blockwise(Op):
super().__init__(**kwargs) super().__init__(**kwargs)
def __getstate__(self):
d = copy(self.__dict__)
d["_gufunc"] = None
return d
def _create_dummy_core_node(self, inputs: Sequence[TensorVariable]) -> Apply: def _create_dummy_core_node(self, inputs: Sequence[TensorVariable]) -> Apply:
core_input_types = [] core_input_types = []
for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig)): for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig)):
...@@ -296,16 +289,23 @@ class Blockwise(Op): ...@@ -296,16 +289,23 @@ class Blockwise(Op):
return rval return rval
def _create_gufunc(self, node): def _create_node_gufunc(self, node) -> None:
"""Define (or retrieve) the node gufunc used in `perform`.
If the Blockwise or core_op have a `gufunc_spec`, the relevant numpy or scipy gufunc is used directly.
Otherwise, we default to `np.vectorize` of the core_op `perform` method for a dummy node.
The gufunc is stored in the tag of the node.
"""
gufunc_spec = self.gufunc_spec or getattr(self.core_op, "gufunc_spec", None) gufunc_spec = self.gufunc_spec or getattr(self.core_op, "gufunc_spec", None)
if gufunc_spec is not None: if gufunc_spec is not None:
self._gufunc = import_func_from_string(gufunc_spec[0]) gufunc = import_func_from_string(gufunc_spec[0])
if self._gufunc: if gufunc is None:
return self._gufunc
else:
raise ValueError(f"Could not import gufunc {gufunc_spec[0]} for {self}") raise ValueError(f"Could not import gufunc {gufunc_spec[0]} for {self}")
else:
# Wrap core_op perform method in numpy vectorize
n_outs = len(self.outputs_sig) n_outs = len(self.outputs_sig)
core_node = self._create_dummy_core_node(node.inputs) core_node = self._create_dummy_core_node(node.inputs)
...@@ -320,8 +320,9 @@ class Blockwise(Op): ...@@ -320,8 +320,9 @@ class Blockwise(Op):
else: else:
return tuple(r[0] for r in inner_outputs) return tuple(r[0] for r in inner_outputs)
self._gufunc = np.vectorize(core_func, signature=self.signature) gufunc = np.vectorize(core_func, signature=self.signature)
return self._gufunc
node.tag.gufunc = gufunc
def _check_runtime_broadcast(self, node, inputs): def _check_runtime_broadcast(self, node, inputs):
batch_ndim = self.batch_ndim(node) batch_ndim = self.batch_ndim(node)
...@@ -340,10 +341,12 @@ class Blockwise(Op): ...@@ -340,10 +341,12 @@ class Blockwise(Op):
) )
def perform(self, node, inputs, output_storage): def perform(self, node, inputs, output_storage):
gufunc = self._gufunc gufunc = getattr(node.tag, "gufunc", None)
if gufunc is None: if gufunc is None:
gufunc = self._create_gufunc(node) # Cache it once per node
self._create_node_gufunc(node)
gufunc = node.tag.gufunc
self._check_runtime_broadcast(node, inputs) self._check_runtime_broadcast(node, inputs)
......
...@@ -28,6 +28,41 @@ from pytensor.tensor.slinalg import ( ...@@ -28,6 +28,41 @@ from pytensor.tensor.slinalg import (
from pytensor.tensor.utils import _parse_gufunc_signature from pytensor.tensor.utils import _parse_gufunc_signature
def test_perform_method_per_node():
    """Each Blockwise node must get its own cached perform implementation.

    This matters whenever the core op's ``perform`` consults node-level
    information (such as input dtypes): a single shared gufunc would be
    wrong for nodes whose inputs differ in dtype.
    """

    class NodeDependentPerformOp(Op):
        # Minimal core op whose perform branches on the node's input dtype.
        def make_node(self, x):
            return Apply(self, [x], [x.type()])

        def perform(self, node, inputs, outputs):
            [x] = inputs
            # Float inputs get +1, everything else -1, so the two nodes
            # below must produce different results from the same Op.
            is_float = node.inputs[0].type.dtype.startswith("float")
            outputs[0][0] = x + 1 if is_float else x - 1

    blockwise_op = Blockwise(core_op=NodeDependentPerformOp(), signature="()->()")

    x = tensor("x", shape=(3,), dtype="float32")
    y = tensor("y", shape=(3,), dtype="int32")
    out_x = blockwise_op(x)
    out_y = blockwise_op(y)

    fn = pytensor.function([x, y], [out_x, out_y])
    [op1, op2] = [node.op for node in fn.maker.fgraph.apply_nodes]
    # Both applications share the very same Op instance...
    assert op1 is blockwise_op
    assert op1 is op2

    # ...yet each node's perform reflects its own input dtype.
    res_out_x, res_out_y = fn(
        np.zeros(3, dtype="float32"), np.zeros(3, dtype="int32")
    )
    np.testing.assert_array_equal(res_out_x, np.ones(3, dtype="float32"))
    np.testing.assert_array_equal(res_out_y, -np.ones(3, dtype="int32"))
def test_vectorize_blockwise(): def test_vectorize_blockwise():
mat = tensor(shape=(None, None)) mat = tensor(shape=(None, None))
tns = tensor(shape=(None, None, None)) tns = tensor(shape=(None, None, None))
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论