提交 47874eb9 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Adapt Numba vectorize iterator for RandomVariables

上级 38c04c96
...@@ -62,10 +62,16 @@ def numba_njit(*args, **kwargs): ...@@ -62,10 +62,16 @@ def numba_njit(*args, **kwargs):
kwargs.setdefault("no_cpython_wrapper", True) kwargs.setdefault("no_cpython_wrapper", True)
kwargs.setdefault("no_cfunc_wrapper", True) kwargs.setdefault("no_cfunc_wrapper", True)
# Supress caching warnings # Suppress cache warning for internal functions
# We have to add an ansi escape code for optional bold text by numba
warnings.filterwarnings( warnings.filterwarnings(
"ignore", "ignore",
message='Cannot cache compiled function "numba_funcified_fgraph" as it uses dynamic globals', message=(
"(\x1b\\[1m)*" # ansi escape code for bold text
"Cannot cache compiled function "
'"(numba_funcified_fgraph|store_core_outputs)" '
"as it uses dynamic globals"
),
category=NumbaWarning, category=NumbaWarning,
) )
......
...@@ -24,6 +24,7 @@ from pytensor.link.numba.dispatch.vectorize_codegen import ( ...@@ -24,6 +24,7 @@ from pytensor.link.numba.dispatch.vectorize_codegen import (
_jit_options, _jit_options,
_vectorized, _vectorized,
encode_literals, encode_literals,
store_core_outputs,
) )
from pytensor.link.utils import compile_function_src, get_name_for_object from pytensor.link.utils import compile_function_src, get_name_for_object
from pytensor.scalar.basic import ( from pytensor.scalar.basic import (
...@@ -480,10 +481,15 @@ def numba_funcify_Elemwise(op, node, **kwargs): ...@@ -480,10 +481,15 @@ def numba_funcify_Elemwise(op, node, **kwargs):
**kwargs, **kwargs,
) )
nin = len(node.inputs)
nout = len(node.outputs)
core_op_fn = store_core_outputs(scalar_op_fn, nin=nin, nout=nout)
input_bc_patterns = tuple([inp.type.broadcastable for inp in node.inputs]) input_bc_patterns = tuple([inp.type.broadcastable for inp in node.inputs])
output_bc_patterns = tuple([out.type.broadcastable for out in node.outputs]) output_bc_patterns = tuple([out.type.broadcastable for out in node.outputs])
output_dtypes = tuple(out.type.dtype for out in node.outputs) output_dtypes = tuple(out.type.dtype for out in node.outputs)
inplace_pattern = tuple(op.inplace_pattern.items()) inplace_pattern = tuple(op.inplace_pattern.items())
core_output_shapes = tuple(() for _ in range(nout))
# numba doesn't support nested literals right now... # numba doesn't support nested literals right now...
input_bc_patterns_enc = encode_literals(input_bc_patterns) input_bc_patterns_enc = encode_literals(input_bc_patterns)
...@@ -493,12 +499,15 @@ def numba_funcify_Elemwise(op, node, **kwargs): ...@@ -493,12 +499,15 @@ def numba_funcify_Elemwise(op, node, **kwargs):
def elemwise_wrapper(*inputs): def elemwise_wrapper(*inputs):
return _vectorized( return _vectorized(
scalar_op_fn, core_op_fn,
input_bc_patterns_enc, input_bc_patterns_enc,
output_bc_patterns_enc, output_bc_patterns_enc,
output_dtypes_enc, output_dtypes_enc,
inplace_pattern_enc, inplace_pattern_enc,
(), # constant_inputs
inputs, inputs,
core_output_shapes, # core_shapes
None, # size
) )
# Pure python implementation, that will be used in tests # Pure python implementation, that will be used in tests
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论