Commit 5fa5c9ba authored by ricardoV94, committed by Ricardo Vieira

Speedup python implementation of Blockwise

Parent 51cda52b
......@@ -502,7 +502,7 @@ class Op(MetaObject):
self,
node: Apply,
storage_map: StorageMapType,
compute_map: ComputeMapType,
compute_map: ComputeMapType | None,
no_recycling: list[Variable],
debug: bool = False,
) -> ThunkType:
......@@ -513,25 +513,38 @@ class Op(MetaObject):
"""
node_input_storage = [storage_map[r] for r in node.inputs]
node_output_storage = [storage_map[r] for r in node.outputs]
node_compute_map = [compute_map[r] for r in node.outputs]
if debug and hasattr(self, "debug_perform"):
p = node.op.debug_perform
else:
p = node.op.perform
@is_thunk_type
def rval(
p=p,
i=node_input_storage,
o=node_output_storage,
n=node,
cm=node_compute_map,
):
r = p(n, [x[0] for x in i], o)
for entry in cm:
entry[0] = True
return r
if compute_map is None:
@is_thunk_type
def rval(
p=p,
i=node_input_storage,
o=node_output_storage,
n=node,
):
return p(n, [x[0] for x in i], o)
else:
node_compute_map = [compute_map[r] for r in node.outputs]
@is_thunk_type
def rval(
p=p,
i=node_input_storage,
o=node_output_storage,
n=node,
cm=node_compute_map,
):
r = p(n, [x[0] for x in i], o)
for entry in cm:
entry[0] = True
return r
rval.inputs = node_input_storage
rval.outputs = node_output_storage
......
......@@ -39,7 +39,7 @@ class COp(Op, CLinkerOp):
self,
node: Apply,
storage_map: StorageMapType,
compute_map: ComputeMapType,
compute_map: ComputeMapType | None,
no_recycling: Collection[Variable],
) -> CThunkWrapperType:
"""Create a thunk for a C implementation.
......@@ -86,11 +86,17 @@ class COp(Op, CLinkerOp):
)
thunk, node_input_filters, node_output_filters = outputs
@is_cthunk_wrapper_type
def rval():
thunk()
for o in node.outputs:
compute_map[o][0] = True
if compute_map is None:
rval = is_cthunk_wrapper_type(thunk)
else:
cm_entries = [compute_map[o] for o in node.outputs]
@is_cthunk_wrapper_type
def rval(thunk=thunk, cm_entries=cm_entries):
thunk()
for entry in cm_entries:
entry[0] = True
rval.thunk = thunk
rval.cthunk = thunk.cthunk
......
......@@ -12,10 +12,11 @@ from pytensor.gradient import grad
from pytensor.graph import Apply, Op
from pytensor.graph.replace import vectorize_node
from pytensor.raise_op import assert_op
from pytensor.tensor import diagonal, log, ones_like, scalar, tensor, vector
from pytensor.tensor import diagonal, dmatrix, log, ones_like, scalar, tensor, vector
from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
from pytensor.tensor.nlinalg import MatrixInverse
from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot
from pytensor.tensor.signal import convolve1d
from pytensor.tensor.slinalg import (
Cholesky,
Solve,
......@@ -484,6 +485,26 @@ def test_batched_mvnormal_logp_and_dlogp(mu_batch_shape, cov_batch_shape, benchm
benchmark(fn, *test_values)
def test_small_blockwise_performance(benchmark):
    """Benchmark a small batched 1d convolution compiled through Blockwise.

    Builds a graph whose output is produced by a ``Blockwise`` Op, checks the
    compiled function against a per-row ``np.convolve`` reference, and then
    benchmarks repeated calls so Python-side thunk overhead regressions show
    up in the timing.
    """
    x = dmatrix(shape=(7, 128))
    y = dmatrix(shape=(7, 20))
    graph_out = convolve1d(x, y, mode="valid")
    fn = pytensor.function([x, y], graph_out, trust_input=True)
    # The rewrite pipeline must not have replaced the Blockwise node,
    # otherwise we would be benchmarking something else entirely.
    assert isinstance(fn.maker.fgraph.outputs[0].owner.op, Blockwise)

    rng = np.random.default_rng(495)
    x_val = rng.normal(size=x.type.shape)
    y_val = rng.normal(size=y.type.shape)

    expected = [
        np.convolve(row_x, row_y, mode="valid")
        for row_x, row_y in zip(x_val, y_val)
    ]
    np.testing.assert_allclose(fn(x_val, y_val), expected)

    benchmark(fn, x_val, y_val)
def test_cop_with_params():
matrix_assert = Blockwise(core_op=assert_op, signature="(x1,x2),()->(x1,x2)")
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register or sign in to comment