Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
2fcb9b2c
提交
2fcb9b2c
authored
12月 21, 2022
作者:
Adrian Seyboldt
提交者:
Adrian Seyboldt
1月 04, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix tests and fix scalar numba return types
上级
48f4db7f
隐藏空白字符变更
内嵌
并排
正在显示
10 个修改的文件
包含
171 行增加
和
71 行删除
+171
-71
basic.py
pytensor/link/numba/dispatch/basic.py
+3
-3
elemwise.py
pytensor/link/numba/dispatch/elemwise.py
+91
-24
elemwise_codegen.py
pytensor/link/numba/dispatch/elemwise_codegen.py
+9
-21
extra_ops.py
pytensor/link/numba/dispatch/extra_ops.py
+1
-0
scalar.py
pytensor/link/numba/dispatch/scalar.py
+3
-0
scan.py
pytensor/link/numba/dispatch/scan.py
+19
-5
linker.py
pytensor/link/numba/linker.py
+2
-2
test_basic.py
tests/link/numba/test_basic.py
+39
-13
test_elemwise.py
tests/link/numba/test_elemwise.py
+3
-3
test_extra_ops.py
tests/link/numba/test_extra_ops.py
+1
-0
没有找到文件。
pytensor/link/numba/dispatch/basic.py
浏览文件 @
2fcb9b2c
...
@@ -204,7 +204,7 @@ enable_slice_boxing()
...
@@ -204,7 +204,7 @@ enable_slice_boxing()
def
to_scalar
(
x
):
def
to_scalar
(
x
):
r
aise
NotImplementedError
()
r
eturn
np
.
asarray
(
x
)
.
item
()
@numba.extending.overload
(
to_scalar
)
@numba.extending.overload
(
to_scalar
)
...
@@ -543,7 +543,7 @@ def {fn_name}({", ".join(input_names)}):
...
@@ -543,7 +543,7 @@ def {fn_name}({", ".join(input_names)}):
{index_prologue}
{index_prologue}
{indices_creation_src}
{indices_creation_src}
{index_body}
{index_body}
return
z
return
np.asarray(z)
"""
"""
return
subtensor_def_src
return
subtensor_def_src
...
@@ -665,7 +665,7 @@ def numba_funcify_Shape_i(op, **kwargs):
...
@@ -665,7 +665,7 @@ def numba_funcify_Shape_i(op, **kwargs):
@numba_njit
@numba_njit
def
shape_i
(
x
):
def
shape_i
(
x
):
return
np
.
shape
(
x
)[
i
]
return
np
.
asarray
(
np
.
shape
(
x
)[
i
])
return
shape_i
return
shape_i
...
...
pytensor/link/numba/dispatch/elemwise.py
浏览文件 @
2fcb9b2c
...
@@ -9,6 +9,7 @@ import numba
...
@@ -9,6 +9,7 @@ import numba
import
numpy
as
np
import
numpy
as
np
from
numba
import
TypingError
,
types
from
numba
import
TypingError
,
types
from
numba.core
import
cgutils
from
numba.core
import
cgutils
from
numba.core.extending
import
overload
from
numba.np
import
arrayobj
from
numba.np
import
arrayobj
from
numpy.core.numeric
import
normalize_axis_index
,
normalize_axis_tuple
from
numpy.core.numeric
import
normalize_axis_index
,
normalize_axis_tuple
...
@@ -174,6 +175,7 @@ def create_axis_reducer(
...
@@ -174,6 +175,7 @@ def create_axis_reducer(
ndim
:
int
,
ndim
:
int
,
dtype
:
numba
.
types
.
Type
,
dtype
:
numba
.
types
.
Type
,
keepdims
:
bool
=
False
,
keepdims
:
bool
=
False
,
return_scalar
=
False
,
)
->
numba
.
core
.
dispatcher
.
Dispatcher
:
)
->
numba
.
core
.
dispatcher
.
Dispatcher
:
r"""Create Python function that performs a NumPy-like reduction on a given axis.
r"""Create Python function that performs a NumPy-like reduction on a given axis.
...
@@ -284,6 +286,8 @@ def {reduce_elemwise_fn_name}(x):
...
@@ -284,6 +286,8 @@ def {reduce_elemwise_fn_name}(x):
inplace_update_statement
=
indent
(
inplace_update_statement
,
" "
*
4
*
2
)
inplace_update_statement
=
indent
(
inplace_update_statement
,
" "
*
4
*
2
)
return_expr
=
"res"
if
keepdims
else
"res.item()"
return_expr
=
"res"
if
keepdims
else
"res.item()"
if
not
return_scalar
:
return_expr
=
f
"np.asarray({return_expr})"
reduce_elemwise_def_src
=
f
"""
reduce_elemwise_def_src
=
f
"""
def {reduce_elemwise_fn_name}(x):
def {reduce_elemwise_fn_name}(x):
...
@@ -305,7 +309,13 @@ def {reduce_elemwise_fn_name}(x):
...
@@ -305,7 +309,13 @@ def {reduce_elemwise_fn_name}(x):
def
create_multiaxis_reducer
(
def
create_multiaxis_reducer
(
scalar_op
,
identity
,
axes
,
ndim
,
dtype
,
input_name
=
"input"
scalar_op
,
identity
,
axes
,
ndim
,
dtype
,
input_name
=
"input"
,
return_scalar
=
False
,
):
):
r"""Construct a function that reduces multiple axes.
r"""Construct a function that reduces multiple axes.
...
@@ -336,6 +346,8 @@ def create_multiaxis_reducer(
...
@@ -336,6 +346,8 @@ def create_multiaxis_reducer(
The number of dimensions of the result.
The number of dimensions of the result.
dtype:
dtype:
The data type of the result.
The data type of the result.
return_scalar:
If True, return a scalar, otherwise an array.
Returns
Returns
=======
=======
...
@@ -370,10 +382,17 @@ def create_multiaxis_reducer(
...
@@ -370,10 +382,17 @@ def create_multiaxis_reducer(
)
)
careduce_assign_lines
=
indent
(
"
\n
"
.
join
(
careduce_lines_src
),
" "
*
4
)
careduce_assign_lines
=
indent
(
"
\n
"
.
join
(
careduce_lines_src
),
" "
*
4
)
if
not
return_scalar
:
pre_result
=
"np.asarray"
post_result
=
""
else
:
pre_result
=
"np.asarray"
post_result
=
".item()"
careduce_def_src
=
f
"""
careduce_def_src
=
f
"""
def {careduce_fn_name}({input_name}):
def {careduce_fn_name}({input_name}):
{careduce_assign_lines}
{careduce_assign_lines}
return
np.asarray({var_name})
return
{pre_result}({var_name}){post_result}
"""
"""
careduce_fn
=
compile_function_src
(
careduce_fn
=
compile_function_src
(
...
@@ -383,7 +402,7 @@ def {careduce_fn_name}({input_name}):
...
@@ -383,7 +402,7 @@ def {careduce_fn_name}({input_name}):
return
careduce_fn
return
careduce_fn
def
jit_compile_reducer
(
node
,
fn
,
**
kwds
):
def
jit_compile_reducer
(
node
,
fn
,
*
,
reduce_to_scalar
=
False
,
*
*
kwds
):
"""Compile Python source for reduction loops using additional optimizations.
"""Compile Python source for reduction loops using additional optimizations.
Parameters
Parameters
...
@@ -400,7 +419,7 @@ def jit_compile_reducer(node, fn, **kwds):
...
@@ -400,7 +419,7 @@ def jit_compile_reducer(node, fn, **kwds):
A :func:`numba.njit`-compiled function.
A :func:`numba.njit`-compiled function.
"""
"""
signature
=
create_numba_signature
(
node
,
reduce_to_scalar
=
True
)
signature
=
create_numba_signature
(
node
,
reduce_to_scalar
=
reduce_to_scalar
)
# Eagerly compile the function using increased optimizations. This should
# Eagerly compile the function using increased optimizations. This should
# help improve nested loop reductions.
# help improve nested loop reductions.
...
@@ -618,23 +637,58 @@ def numba_funcify_Elemwise(op, node, **kwargs):
...
@@ -618,23 +637,58 @@ def numba_funcify_Elemwise(op, node, **kwargs):
inplace_pattern
=
tuple
(
op
.
inplace_pattern
.
items
())
inplace_pattern
=
tuple
(
op
.
inplace_pattern
.
items
())
# numba doesn't support nested literals right now...
# numba doesn't support nested literals right now...
input_bc_patterns
=
base64
.
encodebytes
(
pickle
.
dumps
(
input_bc_patterns
))
.
decode
()
input_bc_patterns_enc
=
base64
.
encodebytes
(
pickle
.
dumps
(
input_bc_patterns
))
.
decode
()
output_bc_patterns
=
base64
.
encodebytes
(
pickle
.
dumps
(
output_bc_patterns
))
.
decode
()
output_bc_patterns_enc
=
base64
.
encodebytes
(
output_dtypes
=
base64
.
encodebytes
(
pickle
.
dumps
(
output_dtypes
))
.
decode
()
pickle
.
dumps
(
output_bc_patterns
)
inplace_pattern
=
base64
.
encodebytes
(
pickle
.
dumps
(
inplace_pattern
))
.
decode
()
)
.
decode
()
output_dtypes_enc
=
base64
.
encodebytes
(
pickle
.
dumps
(
output_dtypes
))
.
decode
()
inplace_pattern_enc
=
base64
.
encodebytes
(
pickle
.
dumps
(
inplace_pattern
))
.
decode
()
@numba_njit
def
elemwise_wrapper
(
*
inputs
):
def
elemwise_wrapper
(
*
inputs
):
return
_vectorized
(
return
_vectorized
(
scalar_op_fn
,
scalar_op_fn
,
input_bc_patterns
,
input_bc_patterns
_enc
,
output_bc_patterns
,
output_bc_patterns
_enc
,
output_dtypes
,
output_dtypes
_enc
,
inplace_pattern
,
inplace_pattern
_enc
,
inputs
,
inputs
,
)
)
return
elemwise_wrapper
# Pure python implementation, that will be used in tests
def
elemwise
(
*
inputs
):
inputs
=
[
np
.
asarray
(
input
)
for
input
in
inputs
]
inputs_bc
=
np
.
broadcast_arrays
(
*
inputs
)
shape
=
inputs
[
0
]
.
shape
for
input
,
bc
in
zip
(
inputs
,
input_bc_patterns
):
for
length
,
allow_bc
,
iter_length
in
zip
(
input
.
shape
,
bc
,
shape
):
if
length
==
1
and
shape
and
iter_length
!=
1
and
not
allow_bc
:
raise
ValueError
(
"Broadcast not allowed."
)
outputs
=
[]
for
dtype
in
output_dtypes
:
outputs
.
append
(
np
.
empty
(
shape
,
dtype
=
dtype
))
for
idx
in
np
.
ndindex
(
shape
):
vals
=
[
input
[
idx
]
for
input
in
inputs_bc
]
outs
=
scalar_op_fn
(
*
vals
)
if
not
isinstance
(
outs
,
tuple
):
outs
=
(
outs
,)
for
out
,
out_val
in
zip
(
outputs
,
outs
):
out
[
idx
]
=
out_val
outputs_summed
=
[]
for
output
,
bc
in
zip
(
outputs
,
output_bc_patterns
):
axes
=
tuple
(
np
.
nonzero
(
bc
)[
0
])
outputs_summed
.
append
(
output
.
sum
(
axes
,
keepdims
=
True
))
if
len
(
outputs_summed
)
!=
1
:
return
tuple
(
outputs_summed
)
return
outputs_summed
[
0
]
@overload
(
elemwise
)
def
ov_elemwise
(
*
inputs
):
return
elemwise_wrapper
return
elemwise
@numba_funcify.register
(
Sum
)
@numba_funcify.register
(
Sum
)
...
@@ -643,7 +697,7 @@ def numba_funcify_Sum(op, node, **kwargs):
...
@@ -643,7 +697,7 @@ def numba_funcify_Sum(op, node, **kwargs):
if
axes
is
None
:
if
axes
is
None
:
axes
=
list
(
range
(
node
.
inputs
[
0
]
.
ndim
))
axes
=
list
(
range
(
node
.
inputs
[
0
]
.
ndim
))
axes
=
list
(
axes
)
axes
=
tuple
(
axes
)
ndim_input
=
node
.
inputs
[
0
]
.
ndim
ndim_input
=
node
.
inputs
[
0
]
.
ndim
...
@@ -658,15 +712,16 @@ def numba_funcify_Sum(op, node, **kwargs):
...
@@ -658,15 +712,16 @@ def numba_funcify_Sum(op, node, **kwargs):
@numba_njit
(
fastmath
=
True
)
@numba_njit
(
fastmath
=
True
)
def
impl_sum
(
array
):
def
impl_sum
(
array
):
# TODO The accumulation itself should happen in acc_dtype...
return
np
.
asarray
(
array
.
sum
(),
dtype
=
np_acc_dtype
)
return
np
.
asarray
(
array
.
sum
())
.
astype
(
np_acc_dtype
)
el
se
:
el
if
len
(
axes
)
==
0
:
@numba_njit
(
fastmath
=
True
)
@numba_njit
(
fastmath
=
True
)
def
impl_sum
(
array
):
def
impl_sum
(
array
):
# TODO The accumulation itself should happen in acc_dtype...
return
array
return
array
.
sum
(
axes
)
.
astype
(
np_acc_dtype
)
else
:
impl_sum
=
numba_funcify_CAReduce
(
op
,
node
,
**
kwargs
)
return
impl_sum
return
impl_sum
...
@@ -705,7 +760,7 @@ def numba_funcify_CAReduce(op, node, **kwargs):
...
@@ -705,7 +760,7 @@ def numba_funcify_CAReduce(op, node, **kwargs):
input_name
=
input_name
,
input_name
=
input_name
,
)
)
careduce_fn
=
jit_compile_reducer
(
node
,
careduce_py_fn
)
careduce_fn
=
jit_compile_reducer
(
node
,
careduce_py_fn
,
reduce_to_scalar
=
False
)
return
careduce_fn
return
careduce_fn
...
@@ -888,7 +943,12 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
...
@@ -888,7 +943,12 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
if
axis
is
not
None
:
if
axis
is
not
None
:
axis
=
normalize_axis_index
(
axis
,
x_at
.
ndim
)
axis
=
normalize_axis_index
(
axis
,
x_at
.
ndim
)
reduce_max_py
=
create_axis_reducer
(
reduce_max_py
=
create_axis_reducer
(
scalar_maximum
,
-
np
.
inf
,
axis
,
x_at
.
ndim
,
x_dtype
,
keepdims
=
True
scalar_maximum
,
-
np
.
inf
,
axis
,
x_at
.
ndim
,
x_dtype
,
keepdims
=
True
,
)
)
reduce_sum_py
=
create_axis_reducer
(
reduce_sum_py
=
create_axis_reducer
(
add_as
,
0.0
,
axis
,
x_at
.
ndim
,
x_dtype
,
keepdims
=
True
add_as
,
0.0
,
axis
,
x_at
.
ndim
,
x_dtype
,
keepdims
=
True
...
@@ -935,10 +995,17 @@ def numba_funcify_MaxAndArgmax(op, node, **kwargs):
...
@@ -935,10 +995,17 @@ def numba_funcify_MaxAndArgmax(op, node, **kwargs):
keep_axes
=
tuple
(
i
for
i
in
range
(
x_ndim
)
if
i
not
in
axes
)
keep_axes
=
tuple
(
i
for
i
in
range
(
x_ndim
)
if
i
not
in
axes
)
reduce_max_py_fn
=
create_multiaxis_reducer
(
reduce_max_py_fn
=
create_multiaxis_reducer
(
scalar_maximum
,
-
np
.
inf
,
axes
,
x_ndim
,
x_dtype
scalar_maximum
,
-
np
.
inf
,
axes
,
x_ndim
,
x_dtype
,
return_scalar
=
False
,
)
)
reduce_max
=
jit_compile_reducer
(
reduce_max
=
jit_compile_reducer
(
Apply
(
node
.
op
,
node
.
inputs
,
[
node
.
outputs
[
0
]
.
clone
()]),
reduce_max_py_fn
Apply
(
node
.
op
,
node
.
inputs
,
[
node
.
outputs
[
0
]
.
clone
()]),
reduce_max_py_fn
,
reduce_to_scalar
=
False
,
)
)
reduced_x_ndim
=
x_ndim
-
len
(
axes
)
+
1
reduced_x_ndim
=
x_ndim
-
len
(
axes
)
+
1
...
...
pytensor/link/numba/dispatch/elemwise_codegen.py
浏览文件 @
2fcb9b2c
...
@@ -117,19 +117,6 @@ def make_loop_call(
...
@@ -117,19 +117,6 @@ def make_loop_call(
# context.printf(builder, "iter shape: " + ', '.join(["%i"] * len(iter_shape)) + "\n", *iter_shape)
# context.printf(builder, "iter shape: " + ', '.join(["%i"] * len(iter_shape)) + "\n", *iter_shape)
# Lower the code of the scalar function so that we can use it in the inner loop
# Caching is set to false to avoid a numba bug TODO ref?
inner_func
=
context
.
compile_subroutine
(
builder
,
# I don't quite understand why we need to access `dispatcher` here.
# The object does seem to be a dispatcher already? But it is missing
# attributes...
scalar_func
.
dispatcher
,
scalar_signature
,
caching
=
False
,
)
inner
=
inner_func
.
fndesc
# Extract shape and stride information from the array.
# Extract shape and stride information from the array.
# For later use in the loop body to do the indexing
# For later use in the loop body to do the indexing
def
extract_array
(
aryty
,
obj
):
def
extract_array
(
aryty
,
obj
):
...
@@ -191,14 +178,15 @@ def make_loop_call(
...
@@ -191,14 +178,15 @@ def make_loop_call(
# val.set_metadata("noalias", output_scope_set)
# val.set_metadata("noalias", output_scope_set)
input_vals
.
append
(
val
)
input_vals
.
append
(
val
)
# Call scalar function
inner_codegen
=
context
.
get_function
(
scalar_func
,
scalar_signature
)
output_values
=
context
.
call_internal
(
builder
,
if
isinstance
(
inner
,
scalar_signature
.
args
[
0
],
(
types
.
StarArgTuple
,
types
.
StarArgUniTuple
)
scalar_signature
,
):
input_vals
,
input_vals
=
[
context
.
make_tuple
(
builder
,
scalar_signature
.
args
[
0
],
input_vals
)]
)
output_values
=
inner_codegen
(
builder
,
input_vals
)
if
isinstance
(
scalar_signature
.
return_type
,
types
.
Tuple
):
if
isinstance
(
scalar_signature
.
return_type
,
(
types
.
Tuple
,
types
.
UniTuple
)):
output_values
=
cgutils
.
unpack_tuple
(
builder
,
output_values
)
output_values
=
cgutils
.
unpack_tuple
(
builder
,
output_values
)
else
:
else
:
output_values
=
[
output_values
]
output_values
=
[
output_values
]
...
...
pytensor/link/numba/dispatch/extra_ops.py
浏览文件 @
2fcb9b2c
...
@@ -364,6 +364,7 @@ def numba_funcify_BroadcastTo(op, node, **kwargs):
...
@@ -364,6 +364,7 @@ def numba_funcify_BroadcastTo(op, node, **kwargs):
lambda
_
:
0
,
len
(
node
.
inputs
)
-
1
lambda
_
:
0
,
len
(
node
.
inputs
)
-
1
)
)
# TODO broadcastable checks
@numba_basic.numba_njit
@numba_basic.numba_njit
def
broadcast_to
(
x
,
*
shape
):
def
broadcast_to
(
x
,
*
shape
):
scalars_shape
=
create_zeros_tuple
()
scalars_shape
=
create_zeros_tuple
()
...
...
pytensor/link/numba/dispatch/scalar.py
浏览文件 @
2fcb9b2c
...
@@ -38,6 +38,9 @@ def numba_funcify_ScalarOp(op, node, **kwargs):
...
@@ -38,6 +38,9 @@ def numba_funcify_ScalarOp(op, node, **kwargs):
# TODO: Do we need to cache these functions so that we don't end up
# TODO: Do we need to cache these functions so that we don't end up
# compiling the same Numba function over and over again?
# compiling the same Numba function over and over again?
if
not
hasattr
(
op
,
"nfunc_spec"
):
return
generate_fallback_impl
(
op
,
node
,
**
kwargs
)
scalar_func_path
=
op
.
nfunc_spec
[
0
]
scalar_func_path
=
op
.
nfunc_spec
[
0
]
scalar_func_numba
=
None
scalar_func_numba
=
None
...
...
pytensor/link/numba/dispatch/scan.py
浏览文件 @
2fcb9b2c
...
@@ -17,7 +17,11 @@ from pytensor.tensor.type import TensorType
...
@@ -17,7 +17,11 @@ from pytensor.tensor.type import TensorType
def
idx_to_str
(
def
idx_to_str
(
array_name
:
str
,
offset
:
int
,
size
:
Optional
[
str
]
=
None
,
idx_symbol
:
str
=
"i"
array_name
:
str
,
offset
:
int
,
size
:
Optional
[
str
]
=
None
,
idx_symbol
:
str
=
"i"
,
allow_scalar
=
False
,
)
->
str
:
)
->
str
:
if
offset
<
0
:
if
offset
<
0
:
indices
=
f
"{idx_symbol} + {array_name}.shape[0] - {offset}"
indices
=
f
"{idx_symbol} + {array_name}.shape[0] - {offset}"
...
@@ -32,7 +36,10 @@ def idx_to_str(
...
@@ -32,7 +36,10 @@ def idx_to_str(
# compensate for this poor `Op`/rewrite design and implementation.
# compensate for this poor `Op`/rewrite design and implementation.
indices
=
f
"({indices})
%
{size}"
indices
=
f
"({indices})
%
{size}"
return
f
"{array_name}[{indices}]"
if
allow_scalar
:
return
f
"{array_name}[{indices}]"
else
:
return
f
"np.asarray({array_name}[{indices}])"
@overload
(
range
)
@overload
(
range
)
...
@@ -115,7 +122,9 @@ def numba_funcify_Scan(op, node, **kwargs):
...
@@ -115,7 +122,9 @@ def numba_funcify_Scan(op, node, **kwargs):
indexed_inner_in_str
=
(
indexed_inner_in_str
=
(
storage_name
storage_name
if
tap_offset
is
None
if
tap_offset
is
None
else
idx_to_str
(
storage_name
,
tap_offset
,
size
=
storage_size_var
)
else
idx_to_str
(
storage_name
,
tap_offset
,
size
=
storage_size_var
,
allow_scalar
=
False
)
)
)
inner_in_exprs
.
append
(
indexed_inner_in_str
)
inner_in_exprs
.
append
(
indexed_inner_in_str
)
...
@@ -232,7 +241,12 @@ def numba_funcify_Scan(op, node, **kwargs):
...
@@ -232,7 +241,12 @@ def numba_funcify_Scan(op, node, **kwargs):
)
)
for
out_tap
in
output_taps
:
for
out_tap
in
output_taps
:
inner_out_to_outer_in_stmts
.
append
(
inner_out_to_outer_in_stmts
.
append
(
idx_to_str
(
storage_name
,
out_tap
,
size
=
storage_size_name
)
idx_to_str
(
storage_name
,
out_tap
,
size
=
storage_size_name
,
allow_scalar
=
True
,
)
)
)
add_output_storage_post_proc_stmt
(
add_output_storage_post_proc_stmt
(
...
@@ -269,7 +283,7 @@ def numba_funcify_Scan(op, node, **kwargs):
...
@@ -269,7 +283,7 @@ def numba_funcify_Scan(op, node, **kwargs):
storage_size_name
=
f
"{outer_in_name}_len"
storage_size_name
=
f
"{outer_in_name}_len"
inner_out_to_outer_in_stmts
.
append
(
inner_out_to_outer_in_stmts
.
append
(
idx_to_str
(
storage_name
,
0
,
size
=
storage_size_name
)
idx_to_str
(
storage_name
,
0
,
size
=
storage_size_name
,
allow_scalar
=
True
)
)
)
add_output_storage_post_proc_stmt
(
storage_name
,
(
0
,),
storage_size_name
)
add_output_storage_post_proc_stmt
(
storage_name
,
(
0
,),
storage_size_name
)
...
...
pytensor/link/numba/linker.py
浏览文件 @
2fcb9b2c
...
@@ -27,9 +27,9 @@ class NumbaLinker(JITLinker):
...
@@ -27,9 +27,9 @@ class NumbaLinker(JITLinker):
return
numba_funcify
(
fgraph
,
**
kwargs
)
return
numba_funcify
(
fgraph
,
**
kwargs
)
def
jit_compile
(
self
,
fn
):
def
jit_compile
(
self
,
fn
):
import
numba
from
pytensor.link.numba.dispatch.basic
import
numba_njit
jitted_fn
=
numba
.
njit
(
fn
)
jitted_fn
=
numba
_
njit
(
fn
)
return
jitted_fn
return
jitted_fn
def
create_thunk_inputs
(
self
,
storage_map
):
def
create_thunk_inputs
(
self
,
storage_map
):
...
...
tests/link/numba/test_basic.py
浏览文件 @
2fcb9b2c
...
@@ -27,6 +27,7 @@ from pytensor.link.numba.dispatch import basic as numba_basic
...
@@ -27,6 +27,7 @@ from pytensor.link.numba.dispatch import basic as numba_basic
from
pytensor.link.numba.dispatch
import
numba_typify
from
pytensor.link.numba.dispatch
import
numba_typify
from
pytensor.link.numba.linker
import
NumbaLinker
from
pytensor.link.numba.linker
import
NumbaLinker
from
pytensor.raise_op
import
assert_op
from
pytensor.raise_op
import
assert_op
from
pytensor.scalar.basic
import
ScalarOp
,
as_scalar
from
pytensor.tensor
import
blas
from
pytensor.tensor
import
blas
from
pytensor.tensor
import
subtensor
as
at_subtensor
from
pytensor.tensor
import
subtensor
as
at_subtensor
from
pytensor.tensor.elemwise
import
Elemwise
from
pytensor.tensor.elemwise
import
Elemwise
...
@@ -63,6 +64,33 @@ class MySingleOut(Op):
...
@@ -63,6 +64,33 @@ class MySingleOut(Op):
outputs
[
0
][
0
]
=
res
outputs
[
0
][
0
]
=
res
class
ScalarMyMultiOut
(
ScalarOp
):
nin
=
2
nout
=
2
@staticmethod
def
impl
(
a
,
b
):
res1
=
2
*
a
res2
=
2
*
b
return
[
res1
,
res2
]
def
make_node
(
self
,
a
,
b
):
a
=
as_scalar
(
a
)
b
=
as_scalar
(
b
)
return
Apply
(
self
,
[
a
,
b
],
[
a
.
type
(),
b
.
type
()])
def
perform
(
self
,
node
,
inputs
,
outputs
):
res1
,
res2
=
self
.
impl
(
inputs
[
0
],
inputs
[
1
])
outputs
[
0
][
0
]
=
res1
outputs
[
1
][
0
]
=
res2
scalar_my_multi_out
=
Elemwise
(
ScalarMyMultiOut
())
scalar_my_multi_out
.
ufunc
=
ScalarMyMultiOut
.
impl
scalar_my_multi_out
.
ufunc
.
nin
=
2
scalar_my_multi_out
.
ufunc
.
nout
=
2
class
MyMultiOut
(
Op
):
class
MyMultiOut
(
Op
):
nin
=
2
nin
=
2
nout
=
2
nout
=
2
...
@@ -86,7 +114,6 @@ my_multi_out = Elemwise(MyMultiOut())
...
@@ -86,7 +114,6 @@ my_multi_out = Elemwise(MyMultiOut())
my_multi_out
.
ufunc
=
MyMultiOut
.
impl
my_multi_out
.
ufunc
=
MyMultiOut
.
impl
my_multi_out
.
ufunc
.
nin
=
2
my_multi_out
.
ufunc
.
nin
=
2
my_multi_out
.
ufunc
.
nout
=
2
my_multi_out
.
ufunc
.
nout
=
2
opts
=
RewriteDatabaseQuery
(
include
=
[
None
],
exclude
=
[
"cxx_only"
,
"BlasOpt"
])
opts
=
RewriteDatabaseQuery
(
include
=
[
None
],
exclude
=
[
"cxx_only"
,
"BlasOpt"
])
numba_mode
=
Mode
(
NumbaLinker
(),
opts
)
numba_mode
=
Mode
(
NumbaLinker
(),
opts
)
py_mode
=
Mode
(
"py"
,
opts
)
py_mode
=
Mode
(
"py"
,
opts
)
...
@@ -988,8 +1015,8 @@ def test_config_options_parallel():
...
@@ -988,8 +1015,8 @@ def test_config_options_parallel():
x
=
at
.
dvector
()
x
=
at
.
dvector
()
with
config
.
change_flags
(
numba__vectorize_target
=
"parallel"
):
with
config
.
change_flags
(
numba__vectorize_target
=
"parallel"
):
pytensor_numba_fn
=
function
([
x
],
x
*
2
,
mode
=
numba_mode
)
pytensor_numba_fn
=
function
([
x
],
at
.
sum
(
x
)
,
mode
=
numba_mode
)
numba_mul_fn
=
pytensor_numba_fn
.
vm
.
jit_fn
.
py_func
.
__globals__
[
"
mul
"
]
numba_mul_fn
=
pytensor_numba_fn
.
vm
.
jit_fn
.
py_func
.
__globals__
[
"
impl_sum
"
]
assert
numba_mul_fn
.
targetoptions
[
"parallel"
]
is
True
assert
numba_mul_fn
.
targetoptions
[
"parallel"
]
is
True
...
@@ -997,8 +1024,9 @@ def test_config_options_fastmath():
...
@@ -997,8 +1024,9 @@ def test_config_options_fastmath():
x
=
at
.
dvector
()
x
=
at
.
dvector
()
with
config
.
change_flags
(
numba__fastmath
=
True
):
with
config
.
change_flags
(
numba__fastmath
=
True
):
pytensor_numba_fn
=
function
([
x
],
x
*
2
,
mode
=
numba_mode
)
pytensor_numba_fn
=
function
([
x
],
at
.
sum
(
x
),
mode
=
numba_mode
)
numba_mul_fn
=
pytensor_numba_fn
.
vm
.
jit_fn
.
py_func
.
__globals__
[
"mul"
]
print
(
list
(
pytensor_numba_fn
.
vm
.
jit_fn
.
py_func
.
__globals__
.
keys
()))
numba_mul_fn
=
pytensor_numba_fn
.
vm
.
jit_fn
.
py_func
.
__globals__
[
"impl_sum"
]
assert
numba_mul_fn
.
targetoptions
[
"fastmath"
]
is
True
assert
numba_mul_fn
.
targetoptions
[
"fastmath"
]
is
True
...
@@ -1006,16 +1034,14 @@ def test_config_options_cached():
...
@@ -1006,16 +1034,14 @@ def test_config_options_cached():
x
=
at
.
dvector
()
x
=
at
.
dvector
()
with
config
.
change_flags
(
numba__cache
=
True
):
with
config
.
change_flags
(
numba__cache
=
True
):
pytensor_numba_fn
=
function
([
x
],
x
*
2
,
mode
=
numba_mode
)
pytensor_numba_fn
=
function
([
x
],
at
.
sum
(
x
),
mode
=
numba_mode
)
numba_mul_fn
=
pytensor_numba_fn
.
vm
.
jit_fn
.
py_func
.
__globals__
[
"mul"
]
numba_mul_fn
=
pytensor_numba_fn
.
vm
.
jit_fn
.
py_func
.
__globals__
[
"impl_sum"
]
assert
not
isinstance
(
assert
not
isinstance
(
numba_mul_fn
.
_cache
,
numba
.
core
.
caching
.
NullCache
)
numba_mul_fn
.
_dispatcher
.
cache
,
numba
.
core
.
caching
.
NullCache
)
with
config
.
change_flags
(
numba__cache
=
False
):
with
config
.
change_flags
(
numba__cache
=
False
):
pytensor_numba_fn
=
function
([
x
],
x
*
2
,
mode
=
numba_mode
)
pytensor_numba_fn
=
function
([
x
],
at
.
sum
(
x
)
,
mode
=
numba_mode
)
numba_mul_fn
=
pytensor_numba_fn
.
vm
.
jit_fn
.
py_func
.
__globals__
[
"
mul
"
]
numba_mul_fn
=
pytensor_numba_fn
.
vm
.
jit_fn
.
py_func
.
__globals__
[
"
impl_sum
"
]
assert
isinstance
(
numba_mul_fn
.
_
dispatcher
.
cache
,
numba
.
core
.
caching
.
NullCache
)
assert
isinstance
(
numba_mul_fn
.
_cache
,
numba
.
core
.
caching
.
NullCache
)
def
test_scalar_return_value_conversion
():
def
test_scalar_return_value_conversion
():
...
...
tests/link/numba/test_elemwise.py
浏览文件 @
2fcb9b2c
...
@@ -16,7 +16,7 @@ from pytensor.tensor.math import All, Any, Max, Mean, Min, Prod, ProdWithoutZero
...
@@ -16,7 +16,7 @@ from pytensor.tensor.math import All, Any, Max, Mean, Min, Prod, ProdWithoutZero
from
pytensor.tensor.special
import
LogSoftmax
,
Softmax
,
SoftmaxGrad
from
pytensor.tensor.special
import
LogSoftmax
,
Softmax
,
SoftmaxGrad
from
tests.link.numba.test_basic
import
(
from
tests.link.numba.test_basic
import
(
compare_numba_and_py
,
compare_numba_and_py
,
my_multi_out
,
scalar_
my_multi_out
,
set_test_value
,
set_test_value
,
)
)
...
@@ -99,8 +99,8 @@ rng = np.random.default_rng(42849)
...
@@ -99,8 +99,8 @@ rng = np.random.default_rng(42849)
rng
.
standard_normal
(
100
)
.
astype
(
config
.
floatX
),
rng
.
standard_normal
(
100
)
.
astype
(
config
.
floatX
),
rng
.
standard_normal
(
100
)
.
astype
(
config
.
floatX
),
rng
.
standard_normal
(
100
)
.
astype
(
config
.
floatX
),
],
],
lambda
x
,
y
:
my_multi_out
(
x
,
y
),
lambda
x
,
y
:
scalar_
my_multi_out
(
x
,
y
),
No
tImplementedError
,
No
ne
,
),
),
],
],
)
)
...
...
tests/link/numba/test_extra_ops.py
浏览文件 @
2fcb9b2c
...
@@ -32,6 +32,7 @@ def test_Bartlett(val):
...
@@ -32,6 +32,7 @@ def test_Bartlett(val):
for
i
in
g_fg
.
inputs
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
(
SharedVariable
,
Constant
))
if
not
isinstance
(
i
,
(
SharedVariable
,
Constant
))
],
],
assert_fn
=
lambda
x
,
y
:
np
.
testing
.
assert_allclose
(
x
,
y
,
atol
=
1e-15
),
)
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论