提交 0f867394 authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: Adrian Seyboldt

Some small formatting and style changes

上级 0b78765a
...@@ -49,8 +49,7 @@ from pytensor.tensor.type_other import MakeSlice, NoneConst ...@@ -49,8 +49,7 @@ from pytensor.tensor.type_other import MakeSlice, NoneConst
def numba_njit(*args, **kwargs): def numba_njit(*args, **kwargs):
kwargs = kwargs.copy() kwargs = kwargs.copy()
if "cache" not in kwargs: kwargs.setdefault("cache", config.numba__cache)
kwargs["cache"] = config.numba__cache
if len(args) > 0 and callable(args[0]): if len(args) > 0 and callable(args[0]):
return numba.njit(*args[1:], **kwargs)(args[0]) return numba.njit(*args[1:], **kwargs)(args[0])
......
...@@ -7,6 +7,7 @@ from numba.misc.special import literal_unroll ...@@ -7,6 +7,7 @@ from numba.misc.special import literal_unroll
from pytensor import config from pytensor import config
from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import get_numba_type, numba_funcify from pytensor.link.numba.dispatch.basic import get_numba_type, numba_funcify
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor.extra_ops import ( from pytensor.tensor.extra_ops import (
Bartlett, Bartlett,
BroadcastTo, BroadcastTo,
...@@ -19,7 +20,6 @@ from pytensor.tensor.extra_ops import ( ...@@ -19,7 +20,6 @@ from pytensor.tensor.extra_ops import (
Unique, Unique,
UnravelIndex, UnravelIndex,
) )
from pytensor.raise_op import CheckAndRaise
@numba_funcify.register(Bartlett) @numba_funcify.register(Bartlett)
...@@ -48,11 +48,13 @@ def numba_funcify_CumOp(op, node, **kwargs): ...@@ -48,11 +48,13 @@ def numba_funcify_CumOp(op, node, **kwargs):
if mode == "add": if mode == "add":
if ndim == 1: if ndim == 1:
@numba_basic.numba_njit(fastmath=config.numba__fastmath) @numba_basic.numba_njit(fastmath=config.numba__fastmath)
def cumop(x): def cumop(x):
return np.cumsum(x) return np.cumsum(x)
else: else:
@numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath) @numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath)
def cumop(x): def cumop(x):
out_dtype = x.dtype out_dtype = x.dtype
...@@ -70,11 +72,13 @@ def numba_funcify_CumOp(op, node, **kwargs): ...@@ -70,11 +72,13 @@ def numba_funcify_CumOp(op, node, **kwargs):
else: else:
if ndim == 1: if ndim == 1:
@numba_basic.numba_njit(fastmath=config.numba__fastmath) @numba_basic.numba_njit(fastmath=config.numba__fastmath)
def cumop(x): def cumop(x):
return np.cumprod(x) return np.cumprod(x)
else: else:
@numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath) @numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath)
def cumop(x): def cumop(x):
out_dtype = x.dtype out_dtype = x.dtype
......
...@@ -144,7 +144,10 @@ def {scalar_op_fn_name}({', '.join(input_names)}): ...@@ -144,7 +144,10 @@ def {scalar_op_fn_name}({', '.join(input_names)}):
signature = create_numba_signature(node, force_scalar=True) signature = create_numba_signature(node, force_scalar=True)
return numba_basic.numba_njit( return numba_basic.numba_njit(
signature, inline="always", fastmath=config.numba__fastmath, cache=False, signature,
inline="always",
fastmath=config.numba__fastmath,
cache=False,
)(scalar_op_fn) )(scalar_op_fn)
......
...@@ -182,7 +182,9 @@ class ConvolutionIndices(Op): ...@@ -182,7 +182,9 @@ class ConvolutionIndices(Op):
# taking into account multiple # taking into account multiple
# input features # input features
col = int( col = int(
iy * inshp[2] + ix + fmapi * np.prod(inshp[1:], dtype=int) iy * inshp[2]
+ ix
+ fmapi * np.prod(inshp[1:], dtype=int)
) )
# convert oy,ox values to output # convert oy,ox values to output
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论