提交 5fa5c9ba authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Speedup python implementation of Blockwise

上级 51cda52b
...@@ -502,7 +502,7 @@ class Op(MetaObject): ...@@ -502,7 +502,7 @@ class Op(MetaObject):
self, self,
node: Apply, node: Apply,
storage_map: StorageMapType, storage_map: StorageMapType,
compute_map: ComputeMapType, compute_map: ComputeMapType | None,
no_recycling: list[Variable], no_recycling: list[Variable],
debug: bool = False, debug: bool = False,
) -> ThunkType: ) -> ThunkType:
...@@ -513,13 +513,26 @@ class Op(MetaObject): ...@@ -513,13 +513,26 @@ class Op(MetaObject):
""" """
node_input_storage = [storage_map[r] for r in node.inputs] node_input_storage = [storage_map[r] for r in node.inputs]
node_output_storage = [storage_map[r] for r in node.outputs] node_output_storage = [storage_map[r] for r in node.outputs]
node_compute_map = [compute_map[r] for r in node.outputs]
if debug and hasattr(self, "debug_perform"): if debug and hasattr(self, "debug_perform"):
p = node.op.debug_perform p = node.op.debug_perform
else: else:
p = node.op.perform p = node.op.perform
if compute_map is None:
@is_thunk_type
def rval(
p=p,
i=node_input_storage,
o=node_output_storage,
n=node,
):
return p(n, [x[0] for x in i], o)
else:
node_compute_map = [compute_map[r] for r in node.outputs]
@is_thunk_type @is_thunk_type
def rval( def rval(
p=p, p=p,
......
...@@ -39,7 +39,7 @@ class COp(Op, CLinkerOp): ...@@ -39,7 +39,7 @@ class COp(Op, CLinkerOp):
self, self,
node: Apply, node: Apply,
storage_map: StorageMapType, storage_map: StorageMapType,
compute_map: ComputeMapType, compute_map: ComputeMapType | None,
no_recycling: Collection[Variable], no_recycling: Collection[Variable],
) -> CThunkWrapperType: ) -> CThunkWrapperType:
"""Create a thunk for a C implementation. """Create a thunk for a C implementation.
...@@ -86,11 +86,17 @@ class COp(Op, CLinkerOp): ...@@ -86,11 +86,17 @@ class COp(Op, CLinkerOp):
) )
thunk, node_input_filters, node_output_filters = outputs thunk, node_input_filters, node_output_filters = outputs
if compute_map is None:
rval = is_cthunk_wrapper_type(thunk)
else:
cm_entries = [compute_map[o] for o in node.outputs]
@is_cthunk_wrapper_type @is_cthunk_wrapper_type
def rval(): def rval(thunk=thunk, cm_entries=cm_entries):
thunk() thunk()
for o in node.outputs: for entry in cm_entries:
compute_map[o][0] = True entry[0] = True
rval.thunk = thunk rval.thunk = thunk
rval.cthunk = thunk.cthunk rval.cthunk = thunk.cthunk
......
...@@ -12,10 +12,11 @@ from pytensor.gradient import grad ...@@ -12,10 +12,11 @@ from pytensor.gradient import grad
from pytensor.graph import Apply, Op from pytensor.graph import Apply, Op
from pytensor.graph.replace import vectorize_node from pytensor.graph.replace import vectorize_node
from pytensor.raise_op import assert_op from pytensor.raise_op import assert_op
from pytensor.tensor import diagonal, log, ones_like, scalar, tensor, vector from pytensor.tensor import diagonal, dmatrix, log, ones_like, scalar, tensor, vector
from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
from pytensor.tensor.nlinalg import MatrixInverse from pytensor.tensor.nlinalg import MatrixInverse
from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot
from pytensor.tensor.signal import convolve1d
from pytensor.tensor.slinalg import ( from pytensor.tensor.slinalg import (
Cholesky, Cholesky,
Solve, Solve,
...@@ -484,6 +485,26 @@ def test_batched_mvnormal_logp_and_dlogp(mu_batch_shape, cov_batch_shape, benchm ...@@ -484,6 +485,26 @@ def test_batched_mvnormal_logp_and_dlogp(mu_batch_shape, cov_batch_shape, benchm
benchmark(fn, *test_values) benchmark(fn, *test_values)
def test_small_blockwise_performance(benchmark):
    """Benchmark a small Blockwise-wrapped convolve1d against a NumPy reference.

    Uses statically shaped inputs so the compiled graph keeps a Blockwise node,
    verifies numerical agreement with per-row ``np.convolve``, then benchmarks
    the compiled function.
    """
    # Fully static shapes: the compiler knows the batch and core dims up front.
    x = dmatrix(shape=(7, 128))
    y = dmatrix(shape=(7, 20))
    result = convolve1d(x, y, mode="valid")
    fn = pytensor.function([x, y], result, trust_input=True)
    # Guard: the output must still be produced by a Blockwise op,
    # otherwise the benchmark would measure something else entirely.
    assert isinstance(fn.maker.fgraph.outputs[0].owner.op, Blockwise)

    rng = np.random.default_rng(495)
    x_val = rng.normal(size=x.type.shape)
    y_val = rng.normal(size=y.type.shape)

    # Reference: batched 1D "valid" convolution computed row by row with NumPy.
    expected = [
        np.convolve(x_val[row], y_val[row], mode="valid")
        for row in range(x_val.shape[0])
    ]
    np.testing.assert_allclose(fn(x_val, y_val), expected)

    benchmark(fn, x_val, y_val)
def test_cop_with_params(): def test_cop_with_params():
matrix_assert = Blockwise(core_op=assert_op, signature="(x1,x2),()->(x1,x2)") matrix_assert = Blockwise(core_op=assert_op, signature="(x1,x2),()->(x1,x2)")
......
Markdown 格式
0%
您已将 0 人添加到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论