提交 5fa5c9ba authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Speedup python implementation of Blockwise

上级 51cda52b
...@@ -502,7 +502,7 @@ class Op(MetaObject): ...@@ -502,7 +502,7 @@ class Op(MetaObject):
self, self,
node: Apply, node: Apply,
storage_map: StorageMapType, storage_map: StorageMapType,
compute_map: ComputeMapType, compute_map: ComputeMapType | None,
no_recycling: list[Variable], no_recycling: list[Variable],
debug: bool = False, debug: bool = False,
) -> ThunkType: ) -> ThunkType:
...@@ -513,25 +513,38 @@ class Op(MetaObject): ...@@ -513,25 +513,38 @@ class Op(MetaObject):
""" """
node_input_storage = [storage_map[r] for r in node.inputs] node_input_storage = [storage_map[r] for r in node.inputs]
node_output_storage = [storage_map[r] for r in node.outputs] node_output_storage = [storage_map[r] for r in node.outputs]
node_compute_map = [compute_map[r] for r in node.outputs]
if debug and hasattr(self, "debug_perform"): if debug and hasattr(self, "debug_perform"):
p = node.op.debug_perform p = node.op.debug_perform
else: else:
p = node.op.perform p = node.op.perform
@is_thunk_type if compute_map is None:
def rval(
p=p, @is_thunk_type
i=node_input_storage, def rval(
o=node_output_storage, p=p,
n=node, i=node_input_storage,
cm=node_compute_map, o=node_output_storage,
): n=node,
r = p(n, [x[0] for x in i], o) ):
for entry in cm: return p(n, [x[0] for x in i], o)
entry[0] = True
return r else:
node_compute_map = [compute_map[r] for r in node.outputs]
@is_thunk_type
def rval(
p=p,
i=node_input_storage,
o=node_output_storage,
n=node,
cm=node_compute_map,
):
r = p(n, [x[0] for x in i], o)
for entry in cm:
entry[0] = True
return r
rval.inputs = node_input_storage rval.inputs = node_input_storage
rval.outputs = node_output_storage rval.outputs = node_output_storage
......
...@@ -39,7 +39,7 @@ class COp(Op, CLinkerOp): ...@@ -39,7 +39,7 @@ class COp(Op, CLinkerOp):
self, self,
node: Apply, node: Apply,
storage_map: StorageMapType, storage_map: StorageMapType,
compute_map: ComputeMapType, compute_map: ComputeMapType | None,
no_recycling: Collection[Variable], no_recycling: Collection[Variable],
) -> CThunkWrapperType: ) -> CThunkWrapperType:
"""Create a thunk for a C implementation. """Create a thunk for a C implementation.
...@@ -86,11 +86,17 @@ class COp(Op, CLinkerOp): ...@@ -86,11 +86,17 @@ class COp(Op, CLinkerOp):
) )
thunk, node_input_filters, node_output_filters = outputs thunk, node_input_filters, node_output_filters = outputs
@is_cthunk_wrapper_type if compute_map is None:
def rval(): rval = is_cthunk_wrapper_type(thunk)
thunk()
for o in node.outputs: else:
compute_map[o][0] = True cm_entries = [compute_map[o] for o in node.outputs]
@is_cthunk_wrapper_type
def rval(thunk=thunk, cm_entries=cm_entries):
thunk()
for entry in cm_entries:
entry[0] = True
rval.thunk = thunk rval.thunk = thunk
rval.cthunk = thunk.cthunk rval.cthunk = thunk.cthunk
......
from collections.abc import Sequence from collections.abc import Callable, Sequence
from typing import Any, cast from typing import Any, cast
import numpy as np import numpy as np
from numpy import broadcast_shapes, empty
from pytensor import config from pytensor import config
from pytensor.compile.builders import OpFromGraph from pytensor.compile.builders import OpFromGraph
...@@ -22,12 +23,111 @@ from pytensor.tensor.type import TensorType, tensor ...@@ -22,12 +23,111 @@ from pytensor.tensor.type import TensorType, tensor
from pytensor.tensor.utils import ( from pytensor.tensor.utils import (
_parse_gufunc_signature, _parse_gufunc_signature,
broadcast_static_dim_lengths, broadcast_static_dim_lengths,
faster_broadcast_to,
faster_ndindex,
import_func_from_string, import_func_from_string,
safe_signature, safe_signature,
) )
from pytensor.tensor.variable import TensorVariable from pytensor.tensor.variable import TensorVariable
def _vectorize_node_perform(
    core_node: Apply,
    batch_bcast_patterns: Sequence[tuple[bool, ...]],
    batch_ndim: int,
    impl: str | None,
) -> Callable:
    """Creates a vectorized `perform` function for a given core node.

    Similar behavior of np.vectorize, but specialized for PyTensor Blockwise Op.

    Parameters
    ----------
    core_node
        The dummy core-op node whose thunk will be called once per batch index.
    batch_bcast_patterns
        Static broadcastable flags of the batch dimensions of each input,
        used to detect disallowed runtime broadcasting.
    batch_ndim
        Number of leading batch dimensions shared by all inputs.
    impl
        Implementation hint forwarded to `make_thunk` (e.g. "py" or None).

    Returns
    -------
    Callable
        A function taking the batched numerical inputs and returning a tuple
        of batched output arrays.
    """
    # One single-element storage cell per variable; these cells are shared
    # with the thunk, so writing inputs / reading outputs is just item access.
    storage_map = {var: [None] for var in core_node.inputs + core_node.outputs}
    # compute_map=None: skip the per-call compute_map bookkeeping in the thunk.
    core_thunk = core_node.op.make_thunk(core_node, storage_map, None, [], impl=impl)
    single_in = len(core_node.inputs) == 1
    core_input_storage = [storage_map[inp] for inp in core_node.inputs]
    core_output_storage = [storage_map[out] for out in core_node.outputs]
    core_storage = core_input_storage + core_output_storage

    # Everything the inner function needs is bound as a default argument so
    # lookups inside the hot loop are local-variable loads.
    def vectorized_perform(
        *args,
        batch_bcast_patterns=batch_bcast_patterns,
        batch_ndim=batch_ndim,
        single_in=single_in,
        core_thunk=core_thunk,
        core_input_storage=core_input_storage,
        core_output_storage=core_output_storage,
        core_storage=core_storage,
    ):
        if single_in:
            # Single input: no cross-input broadcasting possible, skip checks.
            batch_shape = args[0].shape[:batch_ndim]
        else:
            _check_runtime_broadcast_core(args, batch_bcast_patterns, batch_ndim)
            batch_shape = broadcast_shapes(*(arg.shape[:batch_ndim] for arg in args))
            args = list(args)
            for i, arg in enumerate(args):
                if arg.shape[:batch_ndim] != batch_shape:
                    # Materialize explicitly-broadcastable batch dims so every
                    # arg can be indexed by the same batch index below.
                    args[i] = faster_broadcast_to(
                        arg, batch_shape + arg.shape[batch_ndim:]
                    )
        ndindex_iterator = faster_ndindex(batch_shape)
        # Call once to get the output shapes
        try:
            # TODO: Pass core shape as input like BlockwiseWithCoreShape does?
            index0 = next(ndindex_iterator)
        except StopIteration:
            raise NotImplementedError("vectorize with zero size not implemented")
        else:
            for core_input, arg in zip(core_input_storage, args):
                core_input[0] = np.asarray(arg[index0])
            core_thunk()
            # Allocate the batched outputs using the core outputs' shape/dtype.
            outputs = tuple(
                empty(batch_shape + core_output[0].shape, dtype=core_output[0].dtype)
                for core_output in core_output_storage
            )
            for output, core_output in zip(outputs, core_output_storage):
                output[index0] = core_output[0]
        # Remaining batch entries: write inputs, run thunk, copy outputs out.
        for index in ndindex_iterator:
            for core_input, arg in zip(core_input_storage, args):
                core_input[0] = np.asarray(arg[index])
            core_thunk()
            for output, core_output in zip(outputs, core_output_storage):
                output[index] = core_output[0]
        # Clear storage (drop references to intermediate arrays between calls)
        for core_val in core_storage:
            core_val[0] = None
        return outputs

    return vectorized_perform
def _check_runtime_broadcast_core(numerical_inputs, batch_bcast_patterns, batch_ndim):
# strict=None because we are in a hot loop
# We zip together the dimension lengths of each input and their broadcast patterns
for dim_lengths_and_bcast in zip(
*[
zip(input.shape[:batch_ndim], batch_bcast_pattern)
for input, batch_bcast_pattern in zip(
numerical_inputs, batch_bcast_patterns
)
],
):
# If for any dimension where an entry has dim_length != 1,
# and another a dim_length of 1 and broadcastable=False, we have runtime broadcasting.
if (
any(d != 1 for d, _ in dim_lengths_and_bcast)
and (1, False) in dim_lengths_and_bcast
):
raise ValueError(
"Runtime broadcasting not allowed. "
"At least one input has a distinct batch dimension length of 1, but was not marked as broadcastable.\n"
"If broadcasting was intended, use `specify_broadcastable` on the relevant input."
)
class Blockwise(Op): class Blockwise(Op):
"""Generalizes a core `Op` to work with batched dimensions. """Generalizes a core `Op` to work with batched dimensions.
...@@ -308,7 +408,7 @@ class Blockwise(Op): ...@@ -308,7 +408,7 @@ class Blockwise(Op):
return rval return rval
def _create_node_gufunc(self, node) -> None: def _create_node_gufunc(self, node: Apply, impl) -> Callable:
"""Define (or retrieve) the node gufunc used in `perform`. """Define (or retrieve) the node gufunc used in `perform`.
If the Blockwise or core_op have a `gufunc_spec`, the relevant numpy or scipy gufunc is used directly. If the Blockwise or core_op have a `gufunc_spec`, the relevant numpy or scipy gufunc is used directly.
...@@ -316,83 +416,66 @@ class Blockwise(Op): ...@@ -316,83 +416,66 @@ class Blockwise(Op):
The gufunc is stored in the tag of the node. The gufunc is stored in the tag of the node.
""" """
gufunc_spec = self.gufunc_spec or getattr(self.core_op, "gufunc_spec", None) batch_ndim = self.batch_ndim(node)
batch_bcast_patterns = [
if gufunc_spec is not None: inp.type.broadcastable[:batch_ndim] for inp in node.inputs
gufunc = import_func_from_string(gufunc_spec[0]) ]
if gufunc is None: if (
gufunc_spec := self.gufunc_spec
or getattr(self.core_op, "gufunc_spec", None)
) is not None:
core_func = import_func_from_string(gufunc_spec[0])
if core_func is None:
raise ValueError(f"Could not import gufunc {gufunc_spec[0]} for {self}") raise ValueError(f"Could not import gufunc {gufunc_spec[0]} for {self}")
else: if len(node.outputs) == 1:
# Wrap core_op perform method in numpy vectorize
n_outs = len(self.outputs_sig) def gufunc(
core_node = self._create_dummy_core_node(node.inputs) *inputs,
inner_outputs_storage = [[None] for _ in range(n_outs)] batch_bcast_patterns=batch_bcast_patterns,
batch_ndim=batch_ndim,
def core_func( ):
*inner_inputs, _check_runtime_broadcast_core(
core_node=core_node, inputs, batch_bcast_patterns, batch_ndim
inner_outputs_storage=inner_outputs_storage, )
): return (core_func(*inputs),)
self.core_op.perform( else:
core_node,
[np.asarray(inp) for inp in inner_inputs],
inner_outputs_storage,
)
if n_outs == 1:
return inner_outputs_storage[0][0]
else:
return tuple(r[0] for r in inner_outputs_storage)
gufunc = np.vectorize(core_func, signature=self.signature) def gufunc(
*inputs,
batch_bcast_patterns=batch_bcast_patterns,
batch_ndim=batch_ndim,
):
_check_runtime_broadcast_core(
inputs, batch_bcast_patterns, batch_ndim
)
return core_func(*inputs)
else:
core_node = self._create_dummy_core_node(node.inputs) # type: ignore
gufunc = _vectorize_node_perform(
core_node,
batch_bcast_patterns=batch_bcast_patterns,
batch_ndim=self.batch_ndim(node),
impl=impl,
)
node.tag.gufunc = gufunc return gufunc
def _check_runtime_broadcast(self, node, inputs): def _check_runtime_broadcast(self, node, inputs):
batch_ndim = self.batch_ndim(node) batch_ndim = self.batch_ndim(node)
batch_bcast = [pt_inp.type.broadcastable[:batch_ndim] for pt_inp in node.inputs]
_check_runtime_broadcast_core(inputs, batch_bcast, batch_ndim)
# strict=False because we are in a hot loop def prepare_node(self, node, storage_map, compute_map, impl=None):
for dims_and_bcast in zip( node.tag.gufunc = self._create_node_gufunc(node, impl=impl)
*[
zip(
input.shape[:batch_ndim],
sinput.type.broadcastable[:batch_ndim],
strict=False,
)
for input, sinput in zip(inputs, node.inputs, strict=False)
],
strict=False,
):
if any(d != 1 for d, _ in dims_and_bcast) and (1, False) in dims_and_bcast:
raise ValueError(
"Runtime broadcasting not allowed. "
"At least one input has a distinct batch dimension length of 1, but was not marked as broadcastable.\n"
"If broadcasting was intended, use `specify_broadcastable` on the relevant input."
)
def perform(self, node, inputs, output_storage): def perform(self, node, inputs, output_storage):
gufunc = getattr(node.tag, "gufunc", None) try:
if gufunc is None:
# Cache it once per node
self._create_node_gufunc(node)
gufunc = node.tag.gufunc gufunc = node.tag.gufunc
except AttributeError:
self._check_runtime_broadcast(node, inputs) gufunc = node.tag.gufunc = self._create_node_gufunc(node, impl=None)
for out_storage, result in zip(output_storage, gufunc(*inputs)):
res = gufunc(*inputs) out_storage[0] = result
if not isinstance(res, tuple):
res = (res,)
# strict=False because we are in a hot loop
for node_out, out_storage, r in zip(
node.outputs, output_storage, res, strict=False
):
out_dtype = getattr(node_out, "dtype", None)
if out_dtype and out_dtype != r.dtype:
r = np.asarray(r, dtype=out_dtype)
out_storage[0] = r
def __str__(self): def __str__(self):
if self.name is None: if self.name is None:
......
...@@ -12,10 +12,11 @@ from pytensor.gradient import grad ...@@ -12,10 +12,11 @@ from pytensor.gradient import grad
from pytensor.graph import Apply, Op from pytensor.graph import Apply, Op
from pytensor.graph.replace import vectorize_node from pytensor.graph.replace import vectorize_node
from pytensor.raise_op import assert_op from pytensor.raise_op import assert_op
from pytensor.tensor import diagonal, log, ones_like, scalar, tensor, vector from pytensor.tensor import diagonal, dmatrix, log, ones_like, scalar, tensor, vector
from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
from pytensor.tensor.nlinalg import MatrixInverse from pytensor.tensor.nlinalg import MatrixInverse
from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot
from pytensor.tensor.signal import convolve1d
from pytensor.tensor.slinalg import ( from pytensor.tensor.slinalg import (
Cholesky, Cholesky,
Solve, Solve,
...@@ -484,6 +485,26 @@ def test_batched_mvnormal_logp_and_dlogp(mu_batch_shape, cov_batch_shape, benchm ...@@ -484,6 +485,26 @@ def test_batched_mvnormal_logp_and_dlogp(mu_batch_shape, cov_batch_shape, benchm
benchmark(fn, *test_values) benchmark(fn, *test_values)
def test_small_blockwise_performance(benchmark):
    """Benchmark a small batched 1d convolution that lowers to a Blockwise op."""
    x = dmatrix(shape=(7, 128))
    y = dmatrix(shape=(7, 20))
    result = convolve1d(x, y, mode="valid")
    fn = pytensor.function([x, y], result, trust_input=True)
    # The compiled graph must still go through Blockwise for this benchmark
    # to measure what it claims to measure.
    assert isinstance(fn.maker.fgraph.outputs[0].owner.op, Blockwise)

    rng = np.random.default_rng(495)
    x_val = rng.normal(size=x.type.shape)
    y_val = rng.normal(size=y.type.shape)
    expected = [
        np.convolve(x_val[row], y_val[row], mode="valid")
        for row in range(x_val.shape[0])
    ]
    np.testing.assert_allclose(fn(x_val, y_val), expected)
    benchmark(fn, x_val, y_val)
def test_cop_with_params(): def test_cop_with_params():
matrix_assert = Blockwise(core_op=assert_op, signature="(x1,x2),()->(x1,x2)") matrix_assert = Blockwise(core_op=assert_op, signature="(x1,x2),()->(x1,x2)")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论