提交 b6874c66 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Refactor vectorize literal encoding helper

上级 c0a4276d
import base64
import pickle
from collections.abc import Callable
from functools import singledispatch
from numbers import Number
......@@ -22,7 +20,11 @@ from pytensor.link.numba.dispatch.basic import (
numba_njit,
use_optimized_cheap_pass,
)
from pytensor.link.numba.dispatch.vectorize_codegen import _jit_options, _vectorized
from pytensor.link.numba.dispatch.vectorize_codegen import (
_jit_options,
_vectorized,
encode_literals,
)
from pytensor.link.utils import compile_function_src, get_name_for_object
from pytensor.scalar.basic import (
AND,
......@@ -478,19 +480,16 @@ def numba_funcify_Elemwise(op, node, **kwargs):
**kwargs,
)
ndim = node.outputs[0].ndim
output_bc_patterns = tuple([(False,) * ndim for _ in node.outputs])
input_bc_patterns = tuple([input_var.broadcastable for input_var in node.inputs])
output_dtypes = tuple(variable.dtype for variable in node.outputs)
input_bc_patterns = tuple([inp.type.broadcastable for inp in node.inputs])
output_bc_patterns = tuple([out.type.broadcastable for out in node.outputs])
output_dtypes = tuple(out.type.dtype for out in node.outputs)
inplace_pattern = tuple(op.inplace_pattern.items())
# numba doesn't support nested literals right now...
input_bc_patterns_enc = base64.encodebytes(pickle.dumps(input_bc_patterns)).decode()
output_bc_patterns_enc = base64.encodebytes(
pickle.dumps(output_bc_patterns)
).decode()
output_dtypes_enc = base64.encodebytes(pickle.dumps(output_dtypes)).decode()
inplace_pattern_enc = base64.encodebytes(pickle.dumps(inplace_pattern)).decode()
input_bc_patterns_enc = encode_literals(input_bc_patterns)
output_bc_patterns_enc = encode_literals(output_bc_patterns)
output_dtypes_enc = encode_literals(output_dtypes)
inplace_pattern_enc = encode_literals(inplace_pattern)
def elemwise_wrapper(*inputs):
return _vectorized(
......
......@@ -2,6 +2,7 @@ from __future__ import annotations
import base64
import pickle
from collections.abc import Sequence
from typing import Any
import numba
......@@ -13,6 +14,10 @@ from numba.core.base import BaseContext
from numba.np import arrayobj
def encode_literals(literals: Sequence) -> str:
return base64.encodebytes(pickle.dumps(literals)).decode()
_jit_options = {
"fastmath": {
"arcp", # Allow Reciprocal
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论