Commit 5fa5c9ba authored by ricardoV94, committed by Ricardo Vieira

Speedup python implementation of Blockwise

Parent 51cda52b
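
In outline, the commit swaps the generic `np.vectorize` dispatch that `Blockwise.perform` previously used for a manual loop over batch indices driving a pre-made core thunk. The following standalone sketch contrasts the two strategies; it is illustrative only, and `core` is a hypothetical stand-in for a core Op's `perform`:

import numpy as np

def core(a, b):
    # Hypothetical stand-in for a core Op's perform on one batch element
    return np.convolve(a, b, mode="valid")

rng = np.random.default_rng(0)
a = rng.normal(size=(7, 128))
b = rng.normal(size=(7, 20))

# Old strategy: generic per-call dispatch through np.vectorize
ref = np.vectorize(core, signature="(n),(k)->(m)")(a, b)

# New strategy: run the core once to learn the output shape, preallocate,
# then fill the remaining batch entries in a plain Python loop
out0 = core(a[0], b[0])
out = np.empty((7, *out0.shape), dtype=out0.dtype)
out[0] = out0
for i in range(1, 7):
    out[i] = core(a[i], b[i])

np.testing.assert_allclose(out, ref)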
@@ -502,7 +502,7 @@ class Op(MetaObject):
         self,
         node: Apply,
         storage_map: StorageMapType,
-        compute_map: ComputeMapType,
+        compute_map: ComputeMapType | None,
         no_recycling: list[Variable],
         debug: bool = False,
     ) -> ThunkType:
@@ -513,25 +513,38 @@ class Op(MetaObject):
         """
         node_input_storage = [storage_map[r] for r in node.inputs]
         node_output_storage = [storage_map[r] for r in node.outputs]
-        node_compute_map = [compute_map[r] for r in node.outputs]
 
         if debug and hasattr(self, "debug_perform"):
             p = node.op.debug_perform
         else:
             p = node.op.perform
 
-        @is_thunk_type
-        def rval(
-            p=p,
-            i=node_input_storage,
-            o=node_output_storage,
-            n=node,
-            cm=node_compute_map,
-        ):
-            r = p(n, [x[0] for x in i], o)
-            for entry in cm:
-                entry[0] = True
-            return r
+        if compute_map is None:
+
+            @is_thunk_type
+            def rval(
+                p=p,
+                i=node_input_storage,
+                o=node_output_storage,
+                n=node,
+            ):
+                return p(n, [x[0] for x in i], o)
+
+        else:
+            node_compute_map = [compute_map[r] for r in node.outputs]
+
+            @is_thunk_type
+            def rval(
+                p=p,
+                i=node_input_storage,
+                o=node_output_storage,
+                n=node,
+                cm=node_compute_map,
+            ):
+                r = p(n, [x[0] for x in i], o)
+                for entry in cm:
+                    entry[0] = True
+                return r
 
         rval.inputs = node_input_storage
         rval.outputs = node_output_storage
......
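
As a usage note, here is a minimal sketch of the new `compute_map=None` fast path for Python thunks. It is not part of the commit, and it assumes `make_thunk` accepts `impl="py"` in the same way the Blockwise code later in this diff relies on:

import numpy as np
import pytensor.tensor as pt

x = pt.scalar("x")
node = pt.exp(x).owner

# One single-element storage cell per variable, mirroring the code above
storage_map = {var: [None] for var in node.inputs + node.outputs}
storage_map[x][0] = np.asarray(2.0)

# Passing compute_map=None now skips the per-output "computed" bookkeeping
thunk = node.op.make_thunk(node, storage_map, None, [], impl="py")
thunk()
print(storage_map[node.outputs[0]][0])  # ~7.389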
@@ -39,7 +39,7 @@ class COp(Op, CLinkerOp):
         self,
         node: Apply,
         storage_map: StorageMapType,
-        compute_map: ComputeMapType,
+        compute_map: ComputeMapType | None,
         no_recycling: Collection[Variable],
     ) -> CThunkWrapperType:
         """Create a thunk for a C implementation.
@@ -86,11 +86,17 @@ class COp(Op, CLinkerOp):
         )
         thunk, node_input_filters, node_output_filters = outputs
 
-        @is_cthunk_wrapper_type
-        def rval():
-            thunk()
-            for o in node.outputs:
-                compute_map[o][0] = True
+        if compute_map is None:
+            rval = is_cthunk_wrapper_type(thunk)
+        else:
+            cm_entries = [compute_map[o] for o in node.outputs]
+
+            @is_cthunk_wrapper_type
+            def rval(thunk=thunk, cm_entries=cm_entries):
+                thunk()
+                for entry in cm_entries:
+                    entry[0] = True
 
         rval.thunk = thunk
         rval.cthunk = thunk.cthunk
......
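
The same idea sketched in pure Python with hypothetical stand-ins (no C compilation involved): when no `compute_map` is given, the bookkeeping wrapper disappears and the thunk is returned directly:

def fake_cthunk():  # hypothetical stand-in for the compiled thunk
    print("compiled body ran")

def wrap(thunk, compute_map, outputs):
    if compute_map is None:
        return thunk  # fast path: no closure, no per-call bookkeeping
    cm_entries = [compute_map[o] for o in outputs]

    def rval(thunk=thunk, cm_entries=cm_entries):
        thunk()
        for entry in cm_entries:
            entry[0] = True

    return rval

cm = {"out0": [False]}
wrap(fake_cthunk, cm, ["out0"])()
print(cm)  # {'out0': [True]}
wrap(fake_cthunk, None, ["out0"])()  # runs with no bookkeeping at all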
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 from typing import Any, cast
 
 import numpy as np
+from numpy import broadcast_shapes, empty
 
 from pytensor import config
 from pytensor.compile.builders import OpFromGraph
@@ -22,12 +23,111 @@ from pytensor.tensor.type import TensorType, tensor
 from pytensor.tensor.utils import (
     _parse_gufunc_signature,
     broadcast_static_dim_lengths,
+    faster_broadcast_to,
+    faster_ndindex,
     import_func_from_string,
     safe_signature,
 )
 from pytensor.tensor.variable import TensorVariable
+
+
+def _vectorize_node_perform(
+    core_node: Apply,
+    batch_bcast_patterns: Sequence[tuple[bool, ...]],
+    batch_ndim: int,
+    impl: str | None,
+) -> Callable:
+    """Create a vectorized `perform` function for a given core node.
+
+    Similar in behavior to `np.vectorize`, but specialized for PyTensor `Blockwise` Ops.
+    """
+    storage_map = {var: [None] for var in core_node.inputs + core_node.outputs}
+    core_thunk = core_node.op.make_thunk(core_node, storage_map, None, [], impl=impl)
+    single_in = len(core_node.inputs) == 1
+    core_input_storage = [storage_map[inp] for inp in core_node.inputs]
+    core_output_storage = [storage_map[out] for out in core_node.outputs]
+    core_storage = core_input_storage + core_output_storage
+
+    def vectorized_perform(
+        *args,
+        batch_bcast_patterns=batch_bcast_patterns,
+        batch_ndim=batch_ndim,
+        single_in=single_in,
+        core_thunk=core_thunk,
+        core_input_storage=core_input_storage,
+        core_output_storage=core_output_storage,
+        core_storage=core_storage,
+    ):
+        if single_in:
+            batch_shape = args[0].shape[:batch_ndim]
+        else:
+            _check_runtime_broadcast_core(args, batch_bcast_patterns, batch_ndim)
+            batch_shape = broadcast_shapes(*(arg.shape[:batch_ndim] for arg in args))
+            args = list(args)
+            for i, arg in enumerate(args):
+                if arg.shape[:batch_ndim] != batch_shape:
+                    args[i] = faster_broadcast_to(
+                        arg, batch_shape + arg.shape[batch_ndim:]
+                    )
+
+        ndindex_iterator = faster_ndindex(batch_shape)
+        # Call once to get the output shapes
+        try:
+            # TODO: Pass core shape as input like BlockwiseWithCoreShape does?
+            index0 = next(ndindex_iterator)
+        except StopIteration:
+            raise NotImplementedError("vectorize with zero size not implemented")
+        else:
+            for core_input, arg in zip(core_input_storage, args):
+                core_input[0] = np.asarray(arg[index0])
+            core_thunk()
+            outputs = tuple(
+                empty(batch_shape + core_output[0].shape, dtype=core_output[0].dtype)
+                for core_output in core_output_storage
+            )
+            for output, core_output in zip(outputs, core_output_storage):
+                output[index0] = core_output[0]
+
+        for index in ndindex_iterator:
+            for core_input, arg in zip(core_input_storage, args):
+                core_input[0] = np.asarray(arg[index])
+            core_thunk()
+            for output, core_output in zip(outputs, core_output_storage):
+                output[index] = core_output[0]
+
+        # Clear storage
+        for core_val in core_storage:
+            core_val[0] = None
+
+        return outputs
+
+    return vectorized_perform
+
+
+def _check_runtime_broadcast_core(numerical_inputs, batch_bcast_patterns, batch_ndim):
+    # strict=None because we are in a hot loop
+    # We zip together the dimension lengths of each input and their broadcast patterns
+    for dim_lengths_and_bcast in zip(
+        *[
+            zip(input.shape[:batch_ndim], batch_bcast_pattern)
+            for input, batch_bcast_pattern in zip(
+                numerical_inputs, batch_bcast_patterns
+            )
+        ],
+    ):
+        # If any entry has dim_length != 1 while another has dim_length == 1
+        # with broadcastable=False, we have runtime broadcasting.
+        if (
+            any(d != 1 for d, _ in dim_lengths_and_bcast)
+            and (1, False) in dim_lengths_and_bcast
+        ):
+            raise ValueError(
+                "Runtime broadcasting not allowed. "
+                "At least one input has a distinct batch dimension length of 1, but was not marked as broadcastable.\n"
+                "If broadcasting was intended, use `specify_broadcastable` on the relevant input."
+            )
 
 
 class Blockwise(Op):
     """Generalizes a core `Op` to work with batched dimensions.
 
@@ -308,7 +408,7 @@ class Blockwise(Op):
         return rval
 
-    def _create_node_gufunc(self, node) -> None:
+    def _create_node_gufunc(self, node: Apply, impl) -> Callable:
         """Define (or retrieve) the node gufunc used in `perform`.
 
         If the Blockwise or core_op have a `gufunc_spec`, the relevant numpy or scipy gufunc is used directly.
@@ -316,83 +416,66 @@ class Blockwise(Op):
         The gufunc is stored in the tag of the node.
         """
-        gufunc_spec = self.gufunc_spec or getattr(self.core_op, "gufunc_spec", None)
-
-        if gufunc_spec is not None:
-            gufunc = import_func_from_string(gufunc_spec[0])
-            if gufunc is None:
+        batch_ndim = self.batch_ndim(node)
+        batch_bcast_patterns = [
+            inp.type.broadcastable[:batch_ndim] for inp in node.inputs
+        ]
+        if (
+            gufunc_spec := self.gufunc_spec
+            or getattr(self.core_op, "gufunc_spec", None)
+        ) is not None:
+            core_func = import_func_from_string(gufunc_spec[0])
+            if core_func is None:
                 raise ValueError(f"Could not import gufunc {gufunc_spec[0]} for {self}")
-        else:
-            # Wrap core_op perform method in numpy vectorize
-            n_outs = len(self.outputs_sig)
-            core_node = self._create_dummy_core_node(node.inputs)
-            inner_outputs_storage = [[None] for _ in range(n_outs)]
-
-            def core_func(
-                *inner_inputs,
-                core_node=core_node,
-                inner_outputs_storage=inner_outputs_storage,
-            ):
-                self.core_op.perform(
-                    core_node,
-                    [np.asarray(inp) for inp in inner_inputs],
-                    inner_outputs_storage,
-                )
-
-                if n_outs == 1:
-                    return inner_outputs_storage[0][0]
-                else:
-                    return tuple(r[0] for r in inner_outputs_storage)
-
-            gufunc = np.vectorize(core_func, signature=self.signature)
-
-        node.tag.gufunc = gufunc
+
+            if len(node.outputs) == 1:
+
+                def gufunc(
+                    *inputs,
+                    batch_bcast_patterns=batch_bcast_patterns,
+                    batch_ndim=batch_ndim,
+                ):
+                    _check_runtime_broadcast_core(
+                        inputs, batch_bcast_patterns, batch_ndim
+                    )
+                    return (core_func(*inputs),)
+
+            else:
+
+                def gufunc(
+                    *inputs,
+                    batch_bcast_patterns=batch_bcast_patterns,
+                    batch_ndim=batch_ndim,
+                ):
+                    _check_runtime_broadcast_core(
+                        inputs, batch_bcast_patterns, batch_ndim
+                    )
+                    return core_func(*inputs)
+        else:
+            core_node = self._create_dummy_core_node(node.inputs)  # type: ignore
+            gufunc = _vectorize_node_perform(
+                core_node,
+                batch_bcast_patterns=batch_bcast_patterns,
+                batch_ndim=self.batch_ndim(node),
+                impl=impl,
+            )
+        return gufunc
 
     def _check_runtime_broadcast(self, node, inputs):
         batch_ndim = self.batch_ndim(node)
+        batch_bcast = [pt_inp.type.broadcastable[:batch_ndim] for pt_inp in node.inputs]
+        _check_runtime_broadcast_core(inputs, batch_bcast, batch_ndim)
 
-        # strict=False because we are in a hot loop
-        for dims_and_bcast in zip(
-            *[
-                zip(
-                    input.shape[:batch_ndim],
-                    sinput.type.broadcastable[:batch_ndim],
-                    strict=False,
-                )
-                for input, sinput in zip(inputs, node.inputs, strict=False)
-            ],
-            strict=False,
-        ):
-            if any(d != 1 for d, _ in dims_and_bcast) and (1, False) in dims_and_bcast:
-                raise ValueError(
-                    "Runtime broadcasting not allowed. "
-                    "At least one input has a distinct batch dimension length of 1, but was not marked as broadcastable.\n"
-                    "If broadcasting was intended, use `specify_broadcastable` on the relevant input."
-                )
+    def prepare_node(self, node, storage_map, compute_map, impl=None):
+        node.tag.gufunc = self._create_node_gufunc(node, impl=impl)
 
     def perform(self, node, inputs, output_storage):
-        gufunc = getattr(node.tag, "gufunc", None)
-
-        if gufunc is None:
-            # Cache it once per node
-            self._create_node_gufunc(node)
-            gufunc = node.tag.gufunc
-
-        self._check_runtime_broadcast(node, inputs)
-
-        res = gufunc(*inputs)
-        if not isinstance(res, tuple):
-            res = (res,)
-
-        # strict=False because we are in a hot loop
-        for node_out, out_storage, r in zip(
-            node.outputs, output_storage, res, strict=False
-        ):
-            out_dtype = getattr(node_out, "dtype", None)
-            if out_dtype and out_dtype != r.dtype:
-                r = np.asarray(r, dtype=out_dtype)
-            out_storage[0] = r
+        try:
+            gufunc = node.tag.gufunc
+        except AttributeError:
+            gufunc = node.tag.gufunc = self._create_node_gufunc(node, impl=None)
+        for out_storage, result in zip(output_storage, gufunc(*inputs)):
+            out_storage[0] = result
 
     def __str__(self):
         if self.name is None:
......
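
A quick demonstration of the shared runtime-broadcast check; it assumes `_check_runtime_broadcast_core` is importable from `pytensor.tensor.blockwise` exactly as defined in this diff:

import numpy as np
from pytensor.tensor.blockwise import _check_runtime_broadcast_core

a = np.zeros((5, 3))  # batch length 5
b = np.zeros((1, 3))  # batch length 1, but NOT typed as broadcastable
try:
    _check_runtime_broadcast_core(
        [a, b],
        batch_bcast_patterns=[(False,), (False,)],
        batch_ndim=1,
    )
except ValueError as e:
    print(e)  # Runtime broadcasting not allowed. ...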
@@ -12,10 +12,11 @@ from pytensor.gradient import grad
 from pytensor.graph import Apply, Op
 from pytensor.graph.replace import vectorize_node
 from pytensor.raise_op import assert_op
-from pytensor.tensor import diagonal, log, ones_like, scalar, tensor, vector
+from pytensor.tensor import diagonal, dmatrix, log, ones_like, scalar, tensor, vector
 from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
 from pytensor.tensor.nlinalg import MatrixInverse
 from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot
+from pytensor.tensor.signal import convolve1d
 from pytensor.tensor.slinalg import (
     Cholesky,
     Solve,
@@ -484,6 +485,26 @@ def test_batched_mvnormal_logp_and_dlogp(mu_batch_shape, cov_batch_shape, benchmark):
     benchmark(fn, *test_values)
 
 
+def test_small_blockwise_performance(benchmark):
+    a = dmatrix(shape=(7, 128))
+    b = dmatrix(shape=(7, 20))
+    out = convolve1d(a, b, mode="valid")
+    fn = pytensor.function([a, b], out, trust_input=True)
+    assert isinstance(fn.maker.fgraph.outputs[0].owner.op, Blockwise)
+
+    rng = np.random.default_rng(495)
+    a_test = rng.normal(size=a.type.shape)
+    b_test = rng.normal(size=b.type.shape)
+    np.testing.assert_allclose(
+        fn(a_test, b_test),
+        [
+            np.convolve(a_test[i], b_test[i], mode="valid")
+            for i in range(a_test.shape[0])
+        ],
+    )
+    benchmark(fn, a_test, b_test)
+
+
 def test_cop_with_params():
     matrix_assert = Blockwise(core_op=assert_op, signature="(x1,x2),()->(x1,x2)")
......