提交 47874eb9 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Adapt Numba vectorize iterator for RandomVariables

上级 38c04c96
......@@ -62,10 +62,16 @@ def numba_njit(*args, **kwargs):
kwargs.setdefault("no_cpython_wrapper", True)
kwargs.setdefault("no_cfunc_wrapper", True)
# Supress caching warnings
# Suppress cache warning for internal functions
# We have to add an ansi escape code for optional bold text by numba
warnings.filterwarnings(
"ignore",
message='Cannot cache compiled function "numba_funcified_fgraph" as it uses dynamic globals',
message=(
"(\x1b\\[1m)*" # ansi escape code for bold text
"Cannot cache compiled function "
'"(numba_funcified_fgraph|store_core_outputs)" '
"as it uses dynamic globals"
),
category=NumbaWarning,
)
......
......@@ -24,6 +24,7 @@ from pytensor.link.numba.dispatch.vectorize_codegen import (
_jit_options,
_vectorized,
encode_literals,
store_core_outputs,
)
from pytensor.link.utils import compile_function_src, get_name_for_object
from pytensor.scalar.basic import (
......@@ -480,10 +481,15 @@ def numba_funcify_Elemwise(op, node, **kwargs):
**kwargs,
)
nin = len(node.inputs)
nout = len(node.outputs)
core_op_fn = store_core_outputs(scalar_op_fn, nin=nin, nout=nout)
input_bc_patterns = tuple([inp.type.broadcastable for inp in node.inputs])
output_bc_patterns = tuple([out.type.broadcastable for out in node.outputs])
output_dtypes = tuple(out.type.dtype for out in node.outputs)
inplace_pattern = tuple(op.inplace_pattern.items())
core_output_shapes = tuple(() for _ in range(nout))
# numba doesn't support nested literals right now...
input_bc_patterns_enc = encode_literals(input_bc_patterns)
......@@ -493,12 +499,15 @@ def numba_funcify_Elemwise(op, node, **kwargs):
def elemwise_wrapper(*inputs):
return _vectorized(
scalar_op_fn,
core_op_fn,
input_bc_patterns_enc,
output_bc_patterns_enc,
output_dtypes_enc,
inplace_pattern_enc,
(), # constant_inputs
inputs,
core_output_shapes, # core_shapes
None, # size
)
# Pure python implementation, that will be used in tests
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论