提交 0f867394 authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: Adrian Seyboldt

Some small formatting and style changes

上级 0b78765a
......@@ -49,8 +49,7 @@ from pytensor.tensor.type_other import MakeSlice, NoneConst
def numba_njit(*args, **kwargs):
kwargs = kwargs.copy()
if "cache" not in kwargs:
kwargs["cache"] = config.numba__cache
kwargs.setdefault("cache", config.numba__cache)
if len(args) > 0 and callable(args[0]):
return numba.njit(*args[1:], **kwargs)(args[0])
......
......@@ -7,6 +7,7 @@ from numba.misc.special import literal_unroll
from pytensor import config
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import get_numba_type, numba_funcify
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor.extra_ops import (
Bartlett,
BroadcastTo,
......@@ -19,7 +20,6 @@ from pytensor.tensor.extra_ops import (
Unique,
UnravelIndex,
)
from pytensor.raise_op import CheckAndRaise
@numba_funcify.register(Bartlett)
......@@ -48,11 +48,13 @@ def numba_funcify_CumOp(op, node, **kwargs):
if mode == "add":
if ndim == 1:
@numba_basic.numba_njit(fastmath=config.numba__fastmath)
def cumop(x):
return np.cumsum(x)
else:
@numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath)
def cumop(x):
out_dtype = x.dtype
......@@ -70,11 +72,13 @@ def numba_funcify_CumOp(op, node, **kwargs):
else:
if ndim == 1:
@numba_basic.numba_njit(fastmath=config.numba__fastmath)
def cumop(x):
return np.cumprod(x)
else:
@numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath)
def cumop(x):
out_dtype = x.dtype
......
......@@ -144,7 +144,10 @@ def {scalar_op_fn_name}({', '.join(input_names)}):
signature = create_numba_signature(node, force_scalar=True)
return numba_basic.numba_njit(
signature, inline="always", fastmath=config.numba__fastmath, cache=False,
signature,
inline="always",
fastmath=config.numba__fastmath,
cache=False,
)(scalar_op_fn)
......
......@@ -182,7 +182,9 @@ class ConvolutionIndices(Op):
# taking into account multiple
# input features
col = int(
iy * inshp[2] + ix + fmapi * np.prod(inshp[1:], dtype=int)
iy * inshp[2]
+ ix
+ fmapi * np.prod(inshp[1:], dtype=int)
)
# convert oy,ox values to output
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论