提交 b6874c66 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Refactor vectorize literal encoding helper

上级 c0a4276d
import base64
import pickle
from collections.abc import Callable from collections.abc import Callable
from functools import singledispatch from functools import singledispatch
from numbers import Number from numbers import Number
...@@ -22,7 +20,11 @@ from pytensor.link.numba.dispatch.basic import ( ...@@ -22,7 +20,11 @@ from pytensor.link.numba.dispatch.basic import (
numba_njit, numba_njit,
use_optimized_cheap_pass, use_optimized_cheap_pass,
) )
from pytensor.link.numba.dispatch.vectorize_codegen import _jit_options, _vectorized from pytensor.link.numba.dispatch.vectorize_codegen import (
_jit_options,
_vectorized,
encode_literals,
)
from pytensor.link.utils import compile_function_src, get_name_for_object from pytensor.link.utils import compile_function_src, get_name_for_object
from pytensor.scalar.basic import ( from pytensor.scalar.basic import (
AND, AND,
...@@ -478,19 +480,16 @@ def numba_funcify_Elemwise(op, node, **kwargs): ...@@ -478,19 +480,16 @@ def numba_funcify_Elemwise(op, node, **kwargs):
**kwargs, **kwargs,
) )
ndim = node.outputs[0].ndim input_bc_patterns = tuple([inp.type.broadcastable for inp in node.inputs])
output_bc_patterns = tuple([(False,) * ndim for _ in node.outputs]) output_bc_patterns = tuple([out.type.broadcastable for out in node.outputs])
input_bc_patterns = tuple([input_var.broadcastable for input_var in node.inputs]) output_dtypes = tuple(out.type.dtype for out in node.outputs)
output_dtypes = tuple(variable.dtype for variable in node.outputs)
inplace_pattern = tuple(op.inplace_pattern.items()) inplace_pattern = tuple(op.inplace_pattern.items())
# numba doesn't support nested literals right now... # numba doesn't support nested literals right now...
input_bc_patterns_enc = base64.encodebytes(pickle.dumps(input_bc_patterns)).decode() input_bc_patterns_enc = encode_literals(input_bc_patterns)
output_bc_patterns_enc = base64.encodebytes( output_bc_patterns_enc = encode_literals(output_bc_patterns)
pickle.dumps(output_bc_patterns) output_dtypes_enc = encode_literals(output_dtypes)
).decode() inplace_pattern_enc = encode_literals(inplace_pattern)
output_dtypes_enc = base64.encodebytes(pickle.dumps(output_dtypes)).decode()
inplace_pattern_enc = base64.encodebytes(pickle.dumps(inplace_pattern)).decode()
def elemwise_wrapper(*inputs): def elemwise_wrapper(*inputs):
return _vectorized( return _vectorized(
......
...@@ -2,6 +2,7 @@ from __future__ import annotations ...@@ -2,6 +2,7 @@ from __future__ import annotations
import base64 import base64
import pickle import pickle
from collections.abc import Sequence
from typing import Any from typing import Any
import numba import numba
...@@ -13,6 +14,10 @@ from numba.core.base import BaseContext ...@@ -13,6 +14,10 @@ from numba.core.base import BaseContext
from numba.np import arrayobj from numba.np import arrayobj
def encode_literals(literals: Sequence) -> str:
return base64.encodebytes(pickle.dumps(literals)).decode()
_jit_options = { _jit_options = {
"fastmath": { "fastmath": {
"arcp", # Allow Reciprocal "arcp", # Allow Reciprocal
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论