提交 e39fda37 作者: Ricardo Vieira 提交者: Ricardo Vieira

Make blockwise perform method node dependent

上级 a377c22d
from collections.abc import Sequence
from copy import copy
from typing import Any, cast
import numpy as np
......@@ -79,7 +78,6 @@ class Blockwise(Op):
self.name = name
self.inputs_sig, self.outputs_sig = _parse_gufunc_signature(signature)
self.gufunc_spec = gufunc_spec
self._gufunc = None
if destroy_map is not None:
self.destroy_map = destroy_map
if self.destroy_map != core_op.destroy_map:
......@@ -91,11 +89,6 @@ class Blockwise(Op):
super().__init__(**kwargs)
def __getstate__(self):
    """Return the state used when this object is pickled or copied.

    The state is a shallow copy of the instance ``__dict__`` in which the
    ``_gufunc`` entry is replaced by ``None``; the instance itself is left
    untouched.
    """
    state = copy(self.__dict__)
    # Drop the cached gufunc from the exported state only, not from `self`.
    state.update(_gufunc=None)
    return state
def _create_dummy_core_node(self, inputs: Sequence[TensorVariable]) -> Apply:
core_input_types = []
for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig)):
......@@ -296,32 +289,40 @@ class Blockwise(Op):
return rval
def _create_node_gufunc(self, node) -> None:
    """Define the gufunc used by `perform` for ``node`` and cache it on the node.

    If the Blockwise or its core_op has a `gufunc_spec`, the function named by
    ``gufunc_spec[0]`` is imported and used directly. Otherwise the core_op
    `perform` method, bound to a dummy core node built from ``node.inputs``,
    is wrapped with `np.vectorize` using ``self.signature``.

    The result is stored in ``node.tag.gufunc`` rather than on the Op, so
    that two nodes sharing the same Op each get their own function.

    Parameters
    ----------
    node
        The Apply node whose inputs define the dummy core node.

    Raises
    ------
    ValueError
        If a `gufunc_spec` is present but its function cannot be imported.
    """
    gufunc_spec = self.gufunc_spec or getattr(self.core_op, "gufunc_spec", None)

    if gufunc_spec is not None:
        gufunc = import_func_from_string(gufunc_spec[0])
        if gufunc is None:
            raise ValueError(f"Could not import gufunc {gufunc_spec[0]} for {self}")
    else:
        # Wrap core_op perform method in numpy vectorize
        n_outs = len(self.outputs_sig)
        core_node = self._create_dummy_core_node(node.inputs)

        def core_func(*inner_inputs):
            # One single-cell storage list per core output, as `perform` expects.
            inner_outputs = [[None] for _ in range(n_outs)]

            inner_inputs = [np.asarray(inp) for inp in inner_inputs]
            self.core_op.perform(core_node, inner_inputs, inner_outputs)

            if n_outs == 1:
                return inner_outputs[0][0]
            return tuple(storage[0] for storage in inner_outputs)

        gufunc = np.vectorize(core_func, signature=self.signature)

    node.tag.gufunc = gufunc


def _create_gufunc(self, node):
    """Build the gufunc for ``node`` and return it.

    Kept for callers that still use the older name and expect a return value;
    the function is cached in ``node.tag.gufunc`` by `_create_node_gufunc`.
    """
    self._create_node_gufunc(node)
    return node.tag.gufunc
def _check_runtime_broadcast(self, node, inputs):
batch_ndim = self.batch_ndim(node)
......@@ -340,10 +341,12 @@ class Blockwise(Op):
)
def perform(self, node, inputs, output_storage):
gufunc = self._gufunc
gufunc = getattr(node.tag, "gufunc", None)
if gufunc is None:
gufunc = self._create_gufunc(node)
# Cache it once per node
self._create_node_gufunc(node)
gufunc = node.tag.gufunc
self._check_runtime_broadcast(node, inputs)
......
......@@ -28,6 +28,41 @@ from pytensor.tensor.slinalg import (
from pytensor.tensor.utils import _parse_gufunc_signature
def test_perform_method_per_node():
    """Check that Blockwise builds its perform function once per node.

    The core Op below picks its result from the dtype of the node input, so a
    function shared between nodes of different dtypes would give a wrong answer
    for one of them.
    """

    class NodeDependentPerformOp(Op):
        def make_node(self, x):
            return Apply(self, [x], [x.type()])

        def perform(self, node, inputs, outputs):
            [x] = inputs
            # The branch depends on the node, not on the runtime value.
            is_float = node.inputs[0].type.dtype.startswith("float")
            y = x + 1 if is_float else x - 1
            outputs[0][0] = y

    blockwise_op = Blockwise(core_op=NodeDependentPerformOp(), signature="()->()")
    x = tensor("x", shape=(3,), dtype="float32")
    y = tensor("y", shape=(3,), dtype="int32")

    out_x = blockwise_op(x)
    out_y = blockwise_op(y)
    fn = pytensor.function([x, y], [out_x, out_y])

    # Confirm both nodes have the same Op
    first_op, second_op = (apply.op for apply in fn.maker.fgraph.apply_nodes)
    assert first_op is blockwise_op
    assert first_op is second_op

    float_zeros = np.zeros(3, dtype="float32")
    int_zeros = np.zeros(3, dtype="int32")
    res_out_x, res_out_y = fn(float_zeros, int_zeros)
    np.testing.assert_array_equal(res_out_x, np.ones(3, dtype="float32"))
    np.testing.assert_array_equal(res_out_y, -np.ones(3, dtype="int32"))
def test_vectorize_blockwise():
mat = tensor(shape=(None, None))
tns = tensor(shape=(None, None, None))
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论