提交 6557682b 作者: Ricardo Vieira 提交者: Ricardo Vieira

Allow Blockwise to compile C inner thunks

上级 261aaf37
...@@ -16,6 +16,7 @@ from pytensor.graph.replace import ( ...@@ -16,6 +16,7 @@ from pytensor.graph.replace import (
_vectorize_not_needed, _vectorize_not_needed,
vectorize_graph, vectorize_graph,
) )
from pytensor.link.c.op import COp
from pytensor.scalar import ScalarType from pytensor.scalar import ScalarType
from pytensor.tensor import as_tensor_variable from pytensor.tensor import as_tensor_variable
from pytensor.tensor.shape import shape_padleft from pytensor.tensor.shape import shape_padleft
...@@ -43,7 +44,18 @@ def _vectorize_node_perform( ...@@ -43,7 +44,18 @@ def _vectorize_node_perform(
""" """
storage_map = {var: [None] for var in core_node.inputs + core_node.outputs} storage_map = {var: [None] for var in core_node.inputs + core_node.outputs}
core_thunk = core_node.op.make_thunk(core_node, storage_map, None, [], impl=impl) try:
core_thunk = core_node.op.make_thunk(
core_node, storage_map, None, [], impl=impl
)
except NotImplementedError:
if impl == "c":
# Try again with py impl
core_thunk = core_node.op.make_thunk(
core_node, storage_map, None, [], impl="py"
)
else:
raise
single_in = len(core_node.inputs) == 1 single_in = len(core_node.inputs) == 1
core_input_storage = [storage_map[inp] for inp in core_node.inputs] core_input_storage = [storage_map[inp] for inp in core_node.inputs]
core_output_storage = [storage_map[out] for out in core_node.outputs] core_output_storage = [storage_map[out] for out in core_node.outputs]
...@@ -128,7 +140,7 @@ def _check_runtime_broadcast_core(numerical_inputs, batch_bcast_patterns, batch_ ...@@ -128,7 +140,7 @@ def _check_runtime_broadcast_core(numerical_inputs, batch_bcast_patterns, batch_
) )
class Blockwise(Op): class Blockwise(COp):
"""Generalizes a core `Op` to work with batched dimensions. """Generalizes a core `Op` to work with batched dimensions.
TODO: Dispatch JAX (should be easy with the vectorize macro) TODO: Dispatch JAX (should be easy with the vectorize macro)
...@@ -483,6 +495,14 @@ class Blockwise(Op): ...@@ -483,6 +495,14 @@ class Blockwise(Op):
else: else:
return self.name return self.name
def c_code(self, *args, **kwargs):
    """Report that no C implementation exists for this Op.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    NotImplementedError
        Always.
    """
    # Blockwise is a COp only so that the compilation mode can be propagated
    # to the inner Op; it does not have a C implementation of its own yet.
    raise NotImplementedError
def c_code_cache_version(self):
    """Return the version tuple reported for this Op's C code.

    Returns
    -------
    tuple
        The constant one-element tuple ``(-1,)``.
    """
    version = (-1,)
    return version
@_vectorize_node.register(Op) @_vectorize_node.register(Op)
def vectorize_node_fallback(op: Op, node: Apply, *bached_inputs) -> Apply: def vectorize_node_fallback(op: Op, node: Apply, *bached_inputs) -> Apply:
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论