Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
06c5acdf
提交
06c5acdf
authored
1月 26, 2024
作者:
Virgile Andreani
提交者:
Ricardo Vieira
1月 27, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix UP031 with `ruff --unsafe-fixes`
上级
cada5ad2
全部展开
隐藏空白字符变更
内嵌
并排
正在显示
36 个修改的文件
包含
598 行增加
和
606 行删除
+598
-606
pyproject.toml
pyproject.toml
+1
-1
builders.py
pytensor/compile/builders.py
+1
-1
basic.py
pytensor/link/c/basic.py
+56
-62
cmodule.py
pytensor/link/c/cmodule.py
+5
-6
interface.py
pytensor/link/c/interface.py
+13
-11
op.py
pytensor/link/c/op.py
+16
-14
params_type.py
pytensor/link/c/params_type.py
+69
-70
type.py
pytensor/link/c/type.py
+53
-51
slinalg.py
pytensor/link/numba/dispatch/slinalg.py
+2
-2
printing.py
pytensor/printing.py
+1
-2
basic.py
pytensor/scalar/basic.py
+24
-20
math.py
pytensor/scalar/math.py
+10
-10
op.py
pytensor/scan/op.py
+2
-4
basic.py
pytensor/sparse/basic.py
+0
-0
rewriting.py
pytensor/sparse/rewriting.py
+0
-0
basic.py
pytensor/tensor/basic.py
+0
-0
blas.py
pytensor/tensor/blas.py
+58
-59
blas_c.py
pytensor/tensor/blas_c.py
+0
-0
blas_headers.py
pytensor/tensor/blas_headers.py
+23
-23
elemwise.py
pytensor/tensor/elemwise.py
+41
-41
elemwise_cgen.py
pytensor/tensor/elemwise_cgen.py
+23
-23
extra_ops.py
pytensor/tensor/extra_ops.py
+63
-63
math.py
pytensor/tensor/math.py
+5
-5
subtensor.py
pytensor/tensor/subtensor.py
+0
-0
type.py
pytensor/tensor/type.py
+69
-70
basic.py
pytensor/typed_list/basic.py
+54
-58
type.py
pytensor/typed_list/type.py
+9
-10
test_debugmode.py
tests/compile/test_debugmode.py
+0
-0
test_basic.py
tests/link/c/test_basic.py
+0
-0
test_op.py
tests/link/c/test_op.py
+0
-0
test_params_type.py
tests/link/c/test_params_type.py
+0
-0
test_type.py
tests/link/c/test_type.py
+0
-0
c_conv3d_corr3d_ref.py
tests/tensor/conv/c_conv3d_corr3d_ref.py
+0
-0
c_conv_corr_ref.py
tests/tensor/conv/c_conv_corr_ref.py
+0
-0
test_elemwise.py
tests/tensor/rewriting/test_elemwise.py
+0
-0
utils.py
tests/tensor/utils.py
+0
-0
没有找到文件。
pyproject.toml
浏览文件 @
06c5acdf
...
@@ -130,7 +130,7 @@ disable = ["C0330", "C0326"]
...
@@ -130,7 +130,7 @@ disable = ["C0330", "C0326"]
[tool.ruff]
[tool.ruff]
select
=
[
"C"
,
"E"
,
"F"
,
"I"
,
"UP"
,
"W"
]
select
=
[
"C"
,
"E"
,
"F"
,
"I"
,
"UP"
,
"W"
]
ignore
=
[
"C408"
,
"C901"
,
"E501"
,
"E741"
,
"UP031"
]
ignore
=
[
"C408"
,
"C901"
,
"E501"
,
"E741"
]
exclude
=
[
"doc/"
,
"pytensor/_version.py"
,
"bin/pytensor_cache.py"
]
exclude
=
[
"doc/"
,
"pytensor/_version.py"
,
"bin/pytensor_cache.py"
]
...
...
pytensor/compile/builders.py
浏览文件 @
06c5acdf
...
@@ -465,7 +465,7 @@ class OpFromGraph(Op, HasInnerGraph):
...
@@ -465,7 +465,7 @@ class OpFromGraph(Op, HasInnerGraph):
def
__str__
(
self
):
def
__str__
(
self
):
name
=
self
.
__class__
.
__name__
if
self
.
name
is
None
else
self
.
name
name
=
self
.
__class__
.
__name__
if
self
.
name
is
None
else
self
.
name
is_inline
=
self
.
is_inline
is_inline
=
self
.
is_inline
return
"
%(name)
s{inline=
%(is_inline)
s}"
%
locals
(
)
return
"
{name}{{inline={is_inline}}}"
.
format
(
**
locals
()
)
@config.change_flags
(
compute_test_value
=
"off"
)
@config.change_flags
(
compute_test_value
=
"off"
)
def
_recompute_lop_op
(
self
):
def
_recompute_lop_op
(
self
):
...
...
pytensor/link/c/basic.py
浏览文件 @
06c5acdf
...
@@ -250,33 +250,30 @@ def struct_gen(args, struct_builders, blocks, sub):
...
@@ -250,33 +250,30 @@ def struct_gen(args, struct_builders, blocks, sub):
# that holds the type, the value and the traceback. After storing
# that holds the type, the value and the traceback. After storing
# the error, we return the failure code so we know which code
# the error, we return the failure code so we know which code
# block failed.
# block failed.
do_return
=
(
do_return
=
"""
"""
if ({failure_var}) {{
if (
%(failure_var)
s) {
// When there is a failure, this code puts the exception
// When there is a failure, this code puts the exception
// in __ERROR.
// in __ERROR.
PyObject* err_type = NULL;
PyObject* err_type = NULL;
PyObject* err_msg = NULL;
PyObject* err_msg = NULL;
PyObject* err_traceback = NULL;
PyObject* err_traceback = NULL;
PyErr_Fetch(&err_type, &err_msg, &err_traceback);
PyErr_Fetch(&err_type, &err_msg, &err_traceback);
if (!err_type) {
err_type = Py_None;Py_INCREF(Py_None);
}
if (!err_type) {
{err_type = Py_None;Py_INCREF(Py_None);}
}
if (!err_msg) {
err_msg = Py_None; Py_INCREF(Py_None);
}
if (!err_msg) {
{err_msg = Py_None; Py_INCREF(Py_None);}
}
if (!err_traceback) {
err_traceback = Py_None; Py_INCREF(Py_None);
}
if (!err_traceback) {
{err_traceback = Py_None; Py_INCREF(Py_None);}
}
PyObject* old_err_type = PyList_GET_ITEM(__ERROR, 0);
PyObject* old_err_type = PyList_GET_ITEM(__ERROR, 0);
PyObject* old_err_msg = PyList_GET_ITEM(__ERROR, 1);
PyObject* old_err_msg = PyList_GET_ITEM(__ERROR, 1);
PyObject* old_err_traceback = PyList_GET_ITEM(__ERROR, 2);
PyObject* old_err_traceback = PyList_GET_ITEM(__ERROR, 2);
PyList_SET_ITEM(__ERROR, 0, err_type);
PyList_SET_ITEM(__ERROR, 0, err_type);
PyList_SET_ITEM(__ERROR, 1, err_msg);
PyList_SET_ITEM(__ERROR, 1, err_msg);
PyList_SET_ITEM(__ERROR, 2, err_traceback);
PyList_SET_ITEM(__ERROR, 2, err_traceback);
{
Py_XDECREF(old_err_type);
}
{
{Py_XDECREF(old_err_type);}
}
{
Py_XDECREF(old_err_msg);
}
{
{Py_XDECREF(old_err_msg);}
}
{
Py_XDECREF(old_err_traceback);
}
{
{Py_XDECREF(old_err_traceback);}
}
}
}
}
// The failure code is returned to index what code block failed.
// The failure code is returned to index what code block failed.
return
%(failure_var)
s;
return {failure_var};
"""
"""
.
format
(
**
sub
)
%
sub
)
sub
=
dict
(
sub
)
sub
=
dict
(
sub
)
sub
.
update
(
locals
())
sub
.
update
(
locals
())
...
@@ -284,16 +281,15 @@ def struct_gen(args, struct_builders, blocks, sub):
...
@@ -284,16 +281,15 @@ def struct_gen(args, struct_builders, blocks, sub):
# TODO: add some error checking to make sure storage_<x> are
# TODO: add some error checking to make sure storage_<x> are
# 1-element lists and __ERROR is a 3-elements list.
# 1-element lists and __ERROR is a 3-elements list.
struct_code
=
(
struct_code
=
"""
"""
namespace {{
namespace {
struct {name} {{
struct
%(name)
s {
PyObject* __ERROR;
PyObject* __ERROR;
%(storage_decl)
s
{storage_decl}
%(struct_decl)
s
{struct_decl}
%(name)
s()
{
{name}() {
{
// This is only somewhat safe because we:
// This is only somewhat safe because we:
// 1) Are not a virtual class
// 1) Are not a virtual class
// 2) Do not use any virtual classes in the members
// 2) Do not use any virtual classes in the members
...
@@ -306,32 +302,30 @@ def struct_gen(args, struct_builders, blocks, sub):
...
@@ -306,32 +302,30 @@ def struct_gen(args, struct_builders, blocks, sub):
#ifndef PYTENSOR_DONT_MEMSET_STRUCT
#ifndef PYTENSOR_DONT_MEMSET_STRUCT
memset(this, 0, sizeof(*this));
memset(this, 0, sizeof(*this));
#endif
#endif
}
}
}
~
%(name)
s(void)
{
~
{name}(void) {
{
cleanup();
cleanup();
}
}
}
int init(PyObject* __ERROR,
%(args_decl)
s)
{
int init(PyObject* __ERROR,
{args_decl}) {
{
%(storage_incref)
s
{storage_incref}
%(storage_set)
s
{storage_set}
%(struct_init_head)
s
{struct_init_head}
this->__ERROR = __ERROR;
this->__ERROR = __ERROR;
return 0;
return 0;
}
}}
void cleanup(void) {
void cleanup(void) {{
%(struct_cleanup)
s
{struct_cleanup}
%(storage_decref)
s
{storage_decref}
}
}}
int run(void) {
int run(void) {{
int
%(failure_var)
s = 0;
int {failure_var} = 0;
%(behavior)
s
{behavior}
%(do_return)
s
{do_return}
}
}}
};
}};
}
}}
"""
"""
.
format
(
**
sub
)
%
sub
)
return
struct_code
return
struct_code
...
@@ -380,9 +374,9 @@ def get_c_init(fgraph, r, name, sub):
...
@@ -380,9 +374,9 @@ def get_c_init(fgraph, r, name, sub):
pre
=
(
pre
=
(
""
""
"""
"""
py_
%(name)
s
= Py_None;
py_
{name}
= Py_None;
{
Py_XINCREF(py_
%(name)
s);
}
{
{Py_XINCREF(py_{name});}
}
"""
%
locals
(
)
"""
.
format
(
**
locals
()
)
)
)
return
pre
+
r
.
type
.
c_init
(
name
,
sub
)
return
pre
+
r
.
type
.
c_init
(
name
,
sub
)
...
@@ -418,9 +412,9 @@ def get_c_extract(fgraph, r, name, sub):
...
@@ -418,9 +412,9 @@ def get_c_extract(fgraph, r, name, sub):
c_extract
=
r
.
type
.
c_extract
(
name
,
sub
,
False
)
c_extract
=
r
.
type
.
c_extract
(
name
,
sub
,
False
)
pre
=
"""
pre
=
"""
py_
%(name)
s = PyList_GET_ITEM(storage_
%(name)
s
, 0);
py_
{name} = PyList_GET_ITEM(storage_{name}
, 0);
{
Py_XINCREF(py_
%(name)
s);
}
{
{Py_XINCREF(py_{name});}
}
"""
%
locals
(
)
"""
.
format
(
**
locals
()
)
return
pre
+
c_extract
return
pre
+
c_extract
...
@@ -447,9 +441,9 @@ def get_c_extract_out(fgraph, r, name, sub):
...
@@ -447,9 +441,9 @@ def get_c_extract_out(fgraph, r, name, sub):
c_extract
=
r
.
type
.
c_extract_out
(
name
,
sub
,
check_input
,
check_broadcast
=
False
)
c_extract
=
r
.
type
.
c_extract_out
(
name
,
sub
,
check_input
,
check_broadcast
=
False
)
pre
=
"""
pre
=
"""
py_
%(name)
s = PyList_GET_ITEM(storage_
%(name)
s
, 0);
py_
{name} = PyList_GET_ITEM(storage_{name}
, 0);
{
Py_XINCREF(py_
%(name)
s);
}
{
{Py_XINCREF(py_{name});}
}
"""
%
locals
(
)
"""
.
format
(
**
locals
()
)
return
pre
+
c_extract
return
pre
+
c_extract
...
@@ -459,8 +453,8 @@ def get_c_cleanup(fgraph, r, name, sub):
...
@@ -459,8 +453,8 @@ def get_c_cleanup(fgraph, r, name, sub):
"""
"""
post
=
"""
post
=
"""
{
Py_XDECREF(py_
%(name)
s);
}
{
{Py_XDECREF(py_{name});}
}
"""
%
locals
(
)
"""
.
format
(
**
locals
()
)
return
r
.
type
.
c_cleanup
(
name
,
sub
)
+
post
return
r
.
type
.
c_cleanup
(
name
,
sub
)
+
post
...
@@ -470,14 +464,14 @@ def get_c_sync(fgraph, r, name, sub):
...
@@ -470,14 +464,14 @@ def get_c_sync(fgraph, r, name, sub):
"""
"""
return
"""
return
"""
if (!
%(failure_var)
s)
{
if (!
{failure_var}) {
{
%(sync)
s
{sync}
PyObject* old = PyList_GET_ITEM(storage_
%(name)
s
, 0);
PyObject* old = PyList_GET_ITEM(storage_
{name}
, 0);
{
Py_XINCREF(py_
%(name)
s);
}
{
{Py_XINCREF(py_{name});}
}
PyList_SET_ITEM(storage_
%(name)
s, 0, py_
%(name)
s
);
PyList_SET_ITEM(storage_
{name}, 0, py_{name}
);
{
Py_XDECREF(old);
}
{
{Py_XDECREF(old);}
}
}
}
}
"""
%
dict
(
sync
=
r
.
type
.
c_sync
(
name
,
sub
),
name
=
name
,
**
sub
)
"""
.
format
(
**
dict
(
sync
=
r
.
type
.
c_sync
(
name
,
sub
),
name
=
name
,
**
sub
)
)
def
apply_policy
(
fgraph
,
policy
,
r
,
name
,
sub
):
def
apply_policy
(
fgraph
,
policy
,
r
,
name
,
sub
):
...
...
pytensor/link/c/cmodule.py
浏览文件 @
06c5acdf
...
@@ -1950,14 +1950,13 @@ class Compiler:
...
@@ -1950,14 +1950,13 @@ class Compiler:
code
=
(
code
=
(
"""
"""
%(preamble)
s
{preamble}
int main(int argc, char** argv)
int main(int argc, char** argv)
{
{
{
%(body)
s
{body}
return 0;
return 0;
}
}}
"""
"""
.
format
(
**
locals
())
%
locals
()
)
.
encode
()
)
.
encode
()
return
cls
.
_try_compile_tmp
(
return
cls
.
_try_compile_tmp
(
code
,
code
,
...
...
pytensor/link/c/interface.py
浏览文件 @
06c5acdf
...
@@ -558,18 +558,20 @@ class CLinkerType(CLinkerObject):
...
@@ -558,18 +558,20 @@ class CLinkerType(CLinkerObject):
"""
"""
return
"""
return
"""
if (py_
%(name)
s
== Py_None)
if (py_
{name}
== Py_None)
{
{
{
%(c_init_code)
s
{c_init_code}
}
}
}
else
else
{
{{
%(c_extract_code)
s
{c_extract_code}
}
}}
"""
%
dict
(
"""
.
format
(
name
=
name
,
**
dict
(
c_init_code
=
self
.
c_init
(
name
,
sub
),
name
=
name
,
c_extract_code
=
self
.
c_extract
(
name
,
sub
,
check_input
),
c_init_code
=
self
.
c_init
(
name
,
sub
),
c_extract_code
=
self
.
c_extract
(
name
,
sub
,
check_input
),
)
)
)
def
c_cleanup
(
self
,
name
:
str
,
sub
:
dict
[
str
,
str
])
->
str
:
def
c_cleanup
(
self
,
name
:
str
,
sub
:
dict
[
str
,
str
])
->
str
:
...
...
pytensor/link/c/op.py
浏览文件 @
06c5acdf
...
@@ -596,20 +596,22 @@ class ExternalCOp(COp):
...
@@ -596,20 +596,22 @@ class ExternalCOp(COp):
# Generate the C code
# Generate the C code
return
"""
return
"""
%(define_macros)
s
{define_macros}
{
{{
if (
%(func_name)
s(
%(func_args)
s
%(params)
s) != 0) {
if ({func_name}({func_args}{params}) != 0) {{
%(fail)
s
{fail}
}
}}
}
}}
%(undef_macros)
s
{undef_macros}
"""
%
dict
(
"""
.
format
(
func_name
=
self
.
func_name
,
**
dict
(
fail
=
sub
[
"fail"
],
func_name
=
self
.
func_name
,
params
=
params
,
fail
=
sub
[
"fail"
],
func_args
=
self
.
format_c_function_args
(
inp
,
out
),
params
=
params
,
define_macros
=
define_macros
,
func_args
=
self
.
format_c_function_args
(
inp
,
out
),
undef_macros
=
undef_macros
,
define_macros
=
define_macros
,
undef_macros
=
undef_macros
,
)
)
)
else
:
else
:
if
"code"
in
self
.
code_sections
:
if
"code"
in
self
.
code_sections
:
...
...
pytensor/link/c/params_type.py
浏览文件 @
06c5acdf
...
@@ -359,8 +359,7 @@ class ParamsType(CType):
...
@@ -359,8 +359,7 @@ class ParamsType(CType):
type_name
=
type_instance
.
__class__
.
__name__
type_name
=
type_instance
.
__class__
.
__name__
if
not
isinstance
(
type_instance
,
CType
):
if
not
isinstance
(
type_instance
,
CType
):
raise
TypeError
(
raise
TypeError
(
'ParamsType: attribute "
%
s" should inherit from PyTensor CType, got "
%
s".'
f
'ParamsType: attribute "{attribute_name}" should inherit from PyTensor CType, got "{type_name}".'
%
(
attribute_name
,
type_name
)
)
)
self
.
length
=
len
(
kwargs
)
self
.
length
=
len
(
kwargs
)
...
@@ -723,15 +722,11 @@ class ParamsType(CType):
...
@@ -723,15 +722,11 @@ class ParamsType(CType):
c_cleanup_list
.
append
(
type_instance
.
c_cleanup
(
attribute_name
,
sub
))
c_cleanup_list
.
append
(
type_instance
.
c_cleanup
(
attribute_name
,
sub
))
c_extract_list
.
append
(
c_extract_list
.
append
(
"""
f
"""
void extract_
%(attribute_name)
s(PyObject* py_
%(attribute_name)
s)
{
void extract_
{attribute_name}(PyObject* py_{attribute_name}) {
{
%(extract_code)
s
{type_instance.c_extract(attribute_name, sub)}
}
}
}
"""
"""
%
{
"attribute_name"
:
attribute_name
,
"extract_code"
:
type_instance
.
c_extract
(
attribute_name
,
sub
),
}
)
)
struct_declare
=
"
\n
"
.
join
(
c_declare_list
)
struct_declare
=
"
\n
"
.
join
(
c_declare_list
)
...
@@ -759,55 +754,57 @@ class ParamsType(CType):
...
@@ -759,55 +754,57 @@ class ParamsType(CType):
)
)
)
)
final_struct_code
=
"""
final_struct_code
=
"""
/** ParamsType
%(struct_name)
s
**/
/** ParamsType
{struct_name}
**/
#ifndef
%(struct_name_defined)
s
#ifndef
{struct_name_defined}
#define
%(struct_name_defined)
s
#define
{struct_name_defined}
struct
%(struct_name)
s
{
struct
{struct_name} {
{
/* Attributes, */
/* Attributes, */
int
%(struct_name)
s
_error;
int
{struct_name}
_error;
%(struct_declare)
s
{struct_declare}
/* Constructor. */
/* Constructor. */
%(struct_name)
s()
{
{struct_name}() {
{
%(struct_name)
s
_error = 0;
{struct_name}
_error = 0;
%(struct_init)
s
{struct_init}
}
}
}
/* Destructor. */
/* Destructor. */
~
%(struct_name)
s()
{
~
{struct_name}() {
{
// cleanup() is defined below.
// cleanup() is defined below.
cleanup();
cleanup();
}
}
}
/* Cleanup method. */
/* Cleanup method. */
void cleanup() {
void cleanup() {
{
%(struct_cleanup)
s
{struct_cleanup}
}
}
}
/* Extraction methods. */
/* Extraction methods. */
%(struct_extract)
s
{struct_extract}
/* Extract method. */
/* Extract method. */
%(struct_extract_method)
s
{struct_extract_method}
/* Other methods. */
/* Other methods. */
void setErrorOccurred() {
void setErrorOccurred() {
{
++
%(struct_name)
s
_error;
++
{struct_name}
_error;
}
}
}
int errorOccurred() {
int errorOccurred() {
{
return
%(struct_name)
s
_error;
return
{struct_name}
_error;
}
}
}
};
}
}
;
#endif
#endif
/** End ParamsType
%(struct_name)
s **/
/** End ParamsType {struct_name} **/
"""
%
dict
(
"""
.
format
(
struct_name_defined
=
struct_name_defined
,
**
dict
(
struct_name
=
struct_name
,
struct_name_defined
=
struct_name_defined
,
struct_declare
=
struct_declare
,
struct_name
=
struct_name
,
struct_init
=
struct_init
,
struct_declare
=
struct_declare
,
struct_cleanup
=
struct_cleanup
,
struct_init
=
struct_init
,
struct_extract
=
struct_extract
,
struct_cleanup
=
struct_cleanup
,
struct_extract_method
=
struct_extract_method
,
struct_extract
=
struct_extract
,
struct_extract_method
=
struct_extract_method
,
)
)
)
return
sorted
(
c_support_code_set
)
+
[
final_struct_code
]
return
sorted
(
c_support_code_set
)
+
[
final_struct_code
]
...
@@ -822,8 +819,8 @@ class ParamsType(CType):
...
@@ -822,8 +819,8 @@ class ParamsType(CType):
def
c_declare
(
self
,
name
,
sub
,
check_input
=
True
):
def
c_declare
(
self
,
name
,
sub
,
check_input
=
True
):
return
"""
return
"""
%(struct_name)
s*
%(name)
s
;
{struct_name}* {name}
;
"""
%
dict
(
struct_name
=
self
.
name
,
name
=
name
)
"""
.
format
(
**
dict
(
struct_name
=
self
.
name
,
name
=
name
)
)
def
c_init
(
self
,
name
,
sub
):
def
c_init
(
self
,
name
,
sub
):
# NB: It seems c_init() is not called for an op param.
# NB: It seems c_init() is not called for an op param.
...
@@ -841,35 +838,37 @@ class ParamsType(CType):
...
@@ -841,35 +838,37 @@ class ParamsType(CType):
def
c_extract
(
self
,
name
,
sub
,
check_input
=
True
,
**
kwargs
):
def
c_extract
(
self
,
name
,
sub
,
check_input
=
True
,
**
kwargs
):
return
"""
return
"""
/* Seems c_init() is not called for a op param. So I call `new` here. */
/* Seems c_init() is not called for a op param. So I call `new` here. */
%(name)
s = new
%(struct_name)
s
;
{name} = new {struct_name}
;
{ // This need a separate namespace for Clinker
{
{
// This need a separate namespace for Clinker
const char* fields[] = {
%(fields_list)
s
};
const char* fields[] = {
{{fields_list}}
};
if (py_
%(name)
s == Py_None)
{
if (py_
{name} == Py_None) {
{
PyErr_SetString(PyExc_ValueError, "ParamsType: expected an object, not None.");
PyErr_SetString(PyExc_ValueError, "ParamsType: expected an object, not None.");
%(fail)
s
{fail}
}
}
}
for (int i = 0; i <
%(length)
s; ++i)
{
for (int i = 0; i <
{length}; ++i) {
{
PyObject* o = PyDict_GetItemString(py_
%(name)
s
, fields[i]);
PyObject* o = PyDict_GetItemString(py_
{name}
, fields[i]);
if (o == NULL) {
if (o == NULL) {
{
PyErr_Format(PyExc_TypeError, "ParamsType: missing expected attribute
\\
"
%
%
s
\\
" in object.", fields[i]);
PyErr_Format(PyExc_TypeError, "ParamsType: missing expected attribute
\\
"
%
s
\\
" in object.", fields[i]);
%(fail)
s
{fail}
}
}
}
%(name)
s
->extract(o, i);
{name}
->extract(o, i);
if (
%(name)
s->errorOccurred())
{
if (
{name}->errorOccurred()) {
{
/* The extract code from attribute type should have already raised a Python exception,
/* The extract code from attribute type should have already raised a Python exception,
* so we just print the attribute name in stderr. */
* so we just print the attribute name in stderr. */
fprintf(stderr, "
\\
nParamsType: error when extracting value for attribute
\\
"
%%
s
\\
".
\\
n", fields[i]);
fprintf(stderr, "
\\
nParamsType: error when extracting value for attribute
\\
"
%
s
\\
".
\\
n", fields[i]);
%(fail)
s
{fail}
}
}}
}
}}
}
}}
"""
%
dict
(
"""
.
format
(
name
=
name
,
**
dict
(
struct_name
=
self
.
name
,
name
=
name
,
length
=
self
.
length
,
struct_name
=
self
.
name
,
fail
=
sub
[
"fail"
],
length
=
self
.
length
,
fields_list
=
'"
%
s"'
%
'", "'
.
join
(
self
.
fields
),
fail
=
sub
[
"fail"
],
fields_list
=
'"
%
s"'
%
'", "'
.
join
(
self
.
fields
),
)
)
)
def
c_sync
(
self
,
name
,
sub
):
def
c_sync
(
self
,
name
,
sub
):
...
...
pytensor/link/c/type.py
浏览文件 @
06c5acdf
...
@@ -99,11 +99,11 @@ class Generic(CType, Singleton):
...
@@ -99,11 +99,11 @@ class Generic(CType, Singleton):
def
c_sync
(
self
,
name
,
sub
):
def
c_sync
(
self
,
name
,
sub
):
return
"""
return
"""
assert(py_
%(name)
s
->ob_refcnt > 1);
assert(py_
{name}
->ob_refcnt > 1);
Py_DECREF(py_
%(name)
s
);
Py_DECREF(py_
{name}
);
py_
%(name)
s =
%(name)
s ?
%(name)
s
: Py_None;
py_
{name} = {name} ? {name}
: Py_None;
Py_INCREF(py_
%(name)
s
);
Py_INCREF(py_
{name}
);
"""
%
locals
(
)
"""
.
format
(
**
locals
()
)
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
return
(
1
,)
return
(
1
,)
...
@@ -191,17 +191,17 @@ class CDataType(CType[D]):
...
@@ -191,17 +191,17 @@ class CDataType(CType[D]):
def
c_declare
(
self
,
name
,
sub
,
check_input
=
True
):
def
c_declare
(
self
,
name
,
sub
,
check_input
=
True
):
return
"""
return
"""
%(ctype)
s
%(name)
s
;
{ctype} {name}
;
"""
%
dict
(
ctype
=
self
.
ctype
,
name
=
name
)
"""
.
format
(
**
dict
(
ctype
=
self
.
ctype
,
name
=
name
)
)
def
c_init
(
self
,
name
,
sub
):
def
c_init
(
self
,
name
,
sub
):
return
f
"{name} = NULL;"
return
f
"{name} = NULL;"
def
c_extract
(
self
,
name
,
sub
,
check_input
=
True
,
**
kwargs
):
def
c_extract
(
self
,
name
,
sub
,
check_input
=
True
,
**
kwargs
):
return
"""
return
"""
%(name)
s = (
%(ctype)
s)PyCapsule_GetPointer(py_
%(name)
s
, NULL);
{name} = ({ctype})PyCapsule_GetPointer(py_{name}
, NULL);
if (
%(name)
s == NULL)
%(fail)
s
if (
{name} == NULL) {fail}
"""
%
dict
(
name
=
name
,
ctype
=
self
.
ctype
,
fail
=
sub
[
"fail"
]
)
"""
.
format
(
**
dict
(
name
=
name
,
ctype
=
self
.
ctype
,
fail
=
sub
[
"fail"
])
)
def
c_sync
(
self
,
name
,
sub
):
def
c_sync
(
self
,
name
,
sub
):
freefunc
=
self
.
freefunc
freefunc
=
self
.
freefunc
...
@@ -576,39 +576,39 @@ class EnumType(CType, dict):
...
@@ -576,39 +576,39 @@ class EnumType(CType, dict):
"""
"""
return
"""
return
"""
#ifdef DEBUG
#ifdef DEBUG
int pytensor_enum_to_string_
%(cname)
s(
%(ctype)
s in, char* out)
{
int pytensor_enum_to_string_
{cname}({ctype} in, char* out) {
{
int ret = 0;
int ret = 0;
switch(in) {
switch(in) {
{
%(cases)
s
{cases}
default:
default:
PyErr_SetString(PyExc_ValueError, "
%(classname)
s
: unknown enum value.");
PyErr_SetString(PyExc_ValueError, "
{classname}
: unknown enum value.");
ret = -1;
ret = -1;
break;
break;
}
}
}
return ret;
return ret;
}
}
}
#endif
#endif
"""
%
dict
(
"""
.
format
(
cname
=
self
.
cname
,
**
dict
(
ctype
=
self
.
ctype
,
cname
=
self
.
cname
,
classname
=
type
(
self
)
.
__name__
,
ctype
=
self
.
ctype
,
cases
=
""
.
join
(
classname
=
type
(
self
)
.
__name__
,
"""
cases
=
""
.
join
(
case
%(name)
s: sprintf(out, "
%(name)
s"); break;
"""
"""
case {name}: sprintf(out, "{name}"); break;
%
dict
(
name
=
name
)
"""
.
format
(
**
dict
(
name
=
name
))
for
name
in
self
for
name
in
self
),
),
)
)
)
def
c_support_code
(
self
,
**
kwargs
):
def
c_support_code
(
self
,
**
kwargs
):
return
(
return
(
self
.
pyint_compat_code
self
.
pyint_compat_code
+
""
.
join
(
+
""
.
join
(
"""
f
"""
#define
%
s
%
s
#define
{k} {str(self[k])}
"""
"""
%
(
k
,
str
(
self
[
k
]))
for
k
in
sorted
(
self
.
keys
())
for
k
in
sorted
(
self
.
keys
())
)
)
+
self
.
c_to_string
()
+
self
.
c_to_string
()
...
@@ -625,15 +625,15 @@ class EnumType(CType, dict):
...
@@ -625,15 +625,15 @@ class EnumType(CType, dict):
def
c_extract
(
self
,
name
,
sub
,
check_input
=
True
,
**
kwargs
):
def
c_extract
(
self
,
name
,
sub
,
check_input
=
True
,
**
kwargs
):
return
"""
return
"""
if (PyInt_Check(py_
%(name)
s))
{
if (PyInt_Check(py_
{name})) {
{
%(name)
s = (
%(ctype)
s)PyInt_AsLong(py_
%(name)
s
);
{name} = ({ctype})PyInt_AsLong(py_{name}
);
}
else
{
}
} else {
{
%(name)
s = (
%(ctype)
s)PyFloat_AsDouble(py_
%(name)
s
);
{name} = ({ctype})PyFloat_AsDouble(py_{name}
);
}
}
}
if (PyErr_Occurred()) {
if (PyErr_Occurred()) {
{
%(fail)
s
{fail}
}
}
}
"""
%
dict
(
ctype
=
self
.
ctype
,
name
=
name
,
fail
=
sub
[
"fail"
]
)
"""
.
format
(
**
dict
(
ctype
=
self
.
ctype
,
name
=
name
,
fail
=
sub
[
"fail"
])
)
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
return
(
2
,
self
.
ctype
,
self
.
cname
,
tuple
(
self
.
items
()))
return
(
2
,
self
.
ctype
,
self
.
cname
,
tuple
(
self
.
items
()))
...
@@ -754,23 +754,25 @@ class CEnumType(EnumList):
...
@@ -754,23 +754,25 @@ class CEnumType(EnumList):
# swapped_dict's keys are integers.
# swapped_dict's keys are integers.
return
"""
return
"""
switch(PyInt_AsLong(py_
%(name)
s))
{
switch(PyInt_AsLong(py_
{name})) {
{
%(cases)
s
{cases}
default:
default:
PyErr_SetString(PyExc_ValueError, "CEnumType: invalid value to map to C constants.");
PyErr_SetString(PyExc_ValueError, "CEnumType: invalid value to map to C constants.");
{
%(fail)
s
}
{
{{fail}}
}
break;
break;
}
}}
"""
%
dict
(
"""
.
format
(
name
=
name
,
**
dict
(
cases
=
""
.
join
(
name
=
name
,
"""
cases
=
""
.
join
(
"""
case
%(i)
d:
%(name)
s =
%(constant_cname)
s; break;
case
%(i)
d:
%(name)
s =
%(constant_cname)
s; break;
"""
"""
%
dict
(
i
=
i
,
name
=
name
,
constant_cname
=
swapped_dict
[
i
])
%
dict
(
i
=
i
,
name
=
name
,
constant_cname
=
swapped_dict
[
i
])
for
i
in
sorted
(
swapped_dict
.
keys
())
for
i
in
sorted
(
swapped_dict
.
keys
())
),
),
fail
=
sub
[
"fail"
],
fail
=
sub
[
"fail"
],
)
)
)
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
...
...
pytensor/link/numba/dispatch/slinalg.py
浏览文件 @
06c5acdf
...
@@ -141,7 +141,7 @@ def _check_scipy_linalg_matrix(a, func_name):
...
@@ -141,7 +141,7 @@ def _check_scipy_linalg_matrix(a, func_name):
if
isinstance
(
a
,
types
.
Optional
):
if
isinstance
(
a
,
types
.
Optional
):
a
=
a
.
type
a
=
a
.
type
if
not
isinstance
(
a
,
types
.
Array
):
if
not
isinstance
(
a
,
types
.
Array
):
msg
=
"
%
s.
%
s() only supported for array types"
%
interp
msg
=
"
{}.{}() only supported for array types"
.
format
(
*
interp
)
raise
numba
.
TypingError
(
msg
,
highlighting
=
False
)
raise
numba
.
TypingError
(
msg
,
highlighting
=
False
)
if
a
.
ndim
not
in
[
1
,
2
]:
if
a
.
ndim
not
in
[
1
,
2
]:
msg
=
"
%
s.
%
s() only supported on 1d or 2d arrays, found
%
s."
%
(
msg
=
"
%
s.
%
s() only supported on 1d or 2d arrays, found
%
s."
%
(
...
@@ -149,7 +149,7 @@ def _check_scipy_linalg_matrix(a, func_name):
...
@@ -149,7 +149,7 @@ def _check_scipy_linalg_matrix(a, func_name):
)
)
raise
numba
.
TypingError
(
msg
,
highlighting
=
False
)
raise
numba
.
TypingError
(
msg
,
highlighting
=
False
)
if
not
isinstance
(
a
.
dtype
,
(
types
.
Float
,
types
.
Complex
)):
if
not
isinstance
(
a
.
dtype
,
(
types
.
Float
,
types
.
Complex
)):
msg
=
"
%
s.
%
s() only supported on "
"float and complex arrays."
%
interp
msg
=
"
{}.{}() only supported on "
"float and complex arrays."
.
format
(
*
interp
)
raise
numba
.
TypingError
(
msg
,
highlighting
=
False
)
raise
numba
.
TypingError
(
msg
,
highlighting
=
False
)
...
...
pytensor/printing.py
浏览文件 @
06c5acdf
...
@@ -646,8 +646,7 @@ def _debugprint(
...
@@ -646,8 +646,7 @@ def _debugprint(
tot_time_percent
=
(
tot_time_dict
[
node
]
/
profile
.
fct_call_time
)
*
100
tot_time_percent
=
(
tot_time_dict
[
node
]
/
profile
.
fct_call_time
)
*
100
print
(
print
(
"
%
s -->
%8.2
es
%4.1
f
%% %8.2
es
%4.1
f
%%
"
"{} --> {:8.2e}s {:4.1f}
%
{:8.2e}s {:4.1f}
%
"
.
format
(
%
(
var_output
,
var_output
,
op_time
,
op_time
,
op_time_percent
,
op_time_percent
,
...
...
pytensor/scalar/basic.py
浏览文件 @
06c5acdf
...
@@ -466,40 +466,44 @@ class ScalarType(CType, HasDataType, HasShape):
...
@@ -466,40 +466,44 @@ class ScalarType(CType, HasDataType, HasShape):
specs
=
self
.
dtype_specs
()
specs
=
self
.
dtype_specs
()
if
check_input
:
if
check_input
:
pre
=
"""
pre
=
"""
if (!PyObject_TypeCheck(py_
%(name)
s, &
%(pyarr_type)
s
))
if (!PyObject_TypeCheck(py_
{name}, &{pyarr_type}
))
{
{
{
PyErr_Format(PyExc_ValueError,
PyErr_Format(PyExc_ValueError,
"Scalar check failed (
%(dtype)
s)");
"Scalar check failed ({dtype})");
%(fail)
s
{fail}
}
}}
"""
%
dict
(
"""
.
format
(
sub
,
name
=
name
,
dtype
=
specs
[
1
],
pyarr_type
=
"Py
%
sArrType_Type"
%
specs
[
2
]
**
dict
(
sub
,
name
=
name
,
dtype
=
specs
[
1
],
pyarr_type
=
"Py
%
sArrType_Type"
%
specs
[
2
],
)
)
)
else
:
else
:
pre
=
""
pre
=
""
return
(
return
(
pre
pre
+
"""
+
"""
PyArray_ScalarAsCtype(py_
%(name)
s, &
%(name)
s);
PyArray_ScalarAsCtype(py_{name}, &{name});
"""
"""
.
format
(
**
dict
(
sub
,
name
=
name
))
%
dict
(
sub
,
name
=
name
)
)
)
def
c_sync
(
self
,
name
,
sub
):
def
c_sync
(
self
,
name
,
sub
):
specs
=
self
.
dtype_specs
()
specs
=
self
.
dtype_specs
()
return
"""
return
"""
Py_XDECREF(py_
%(name)
s
);
Py_XDECREF(py_
{name}
);
py_
%(name)
s = PyArrayScalar_New(
%(cls)
s
);
py_
{name} = PyArrayScalar_New({cls}
);
if (!py_
%(name)
s
)
if (!py_
{name}
)
{
{
{
Py_XINCREF(Py_None);
Py_XINCREF(Py_None);
py_
%(name)
s
= Py_None;
py_
{name}
= Py_None;
PyErr_Format(PyExc_MemoryError,
PyErr_Format(PyExc_MemoryError,
"Instantiation of new Python scalar failed (
%(dtype)
s
)");
"Instantiation of new Python scalar failed (
{dtype}
)");
%(fail)
s
{fail}
}
}
}
PyArrayScalar_ASSIGN(py_
%(name)
s,
%(cls)
s,
%(name)
s
);
PyArrayScalar_ASSIGN(py_
{name}, {cls}, {name}
);
"""
%
dict
(
sub
,
name
=
name
,
dtype
=
specs
[
1
],
cls
=
specs
[
2
]
)
"""
.
format
(
**
dict
(
sub
,
name
=
name
,
dtype
=
specs
[
1
],
cls
=
specs
[
2
])
)
def
c_cleanup
(
self
,
name
,
sub
):
def
c_cleanup
(
self
,
name
,
sub
):
return
""
return
""
...
...
pytensor/scalar/math.py
浏览文件 @
06c5acdf
...
@@ -620,8 +620,8 @@ class Chi2SF(BinaryScalarOp):
...
@@ -620,8 +620,8 @@ class Chi2SF(BinaryScalarOp):
(
z
,)
=
out
(
z
,)
=
out
if
node
.
inputs
[
0
]
.
type
in
float_types
:
if
node
.
inputs
[
0
]
.
type
in
float_types
:
dtype
=
"npy_"
+
node
.
outputs
[
0
]
.
dtype
dtype
=
"npy_"
+
node
.
outputs
[
0
]
.
dtype
return
"""
%(z)
s
=
return
"""
{z}
=
(
%(dtype)
s) 1 - GammaP(
%(k)
s/2.,
%(x)
s/2.);"""
%
locals
(
)
(
{dtype}) 1 - GammaP({k}/2., {x}/2.);"""
.
format
(
**
locals
()
)
raise
NotImplementedError
(
"only floatingpoint is implemented"
)
raise
NotImplementedError
(
"only floatingpoint is implemented"
)
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
...
@@ -666,8 +666,8 @@ class GammaInc(BinaryScalarOp):
...
@@ -666,8 +666,8 @@ class GammaInc(BinaryScalarOp):
(
z
,)
=
out
(
z
,)
=
out
if
node
.
inputs
[
0
]
.
type
in
float_types
:
if
node
.
inputs
[
0
]
.
type
in
float_types
:
dtype
=
"npy_"
+
node
.
outputs
[
0
]
.
dtype
dtype
=
"npy_"
+
node
.
outputs
[
0
]
.
dtype
return
"""
%(z)
s
=
return
"""
{z}
=
(
%(dtype)
s) GammaP(
%(k)
s,
%(x)
s);"""
%
locals
(
)
(
{dtype}) GammaP({k}, {x});"""
.
format
(
**
locals
()
)
raise
NotImplementedError
(
"only floatingpoint is implemented"
)
raise
NotImplementedError
(
"only floatingpoint is implemented"
)
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
...
@@ -712,8 +712,8 @@ class GammaIncC(BinaryScalarOp):
...
@@ -712,8 +712,8 @@ class GammaIncC(BinaryScalarOp):
(
z
,)
=
out
(
z
,)
=
out
if
node
.
inputs
[
0
]
.
type
in
float_types
:
if
node
.
inputs
[
0
]
.
type
in
float_types
:
dtype
=
"npy_"
+
node
.
outputs
[
0
]
.
dtype
dtype
=
"npy_"
+
node
.
outputs
[
0
]
.
dtype
return
"""
%(z)
s
=
return
"""
{z}
=
(
%(dtype)
s) GammaQ(
%(k)
s,
%(x)
s);"""
%
locals
(
)
(
{dtype}) GammaQ({k}, {x});"""
.
format
(
**
locals
()
)
raise
NotImplementedError
(
"only floatingpoint is implemented"
)
raise
NotImplementedError
(
"only floatingpoint is implemented"
)
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
...
@@ -1018,8 +1018,8 @@ class GammaU(BinaryScalarOp):
...
@@ -1018,8 +1018,8 @@ class GammaU(BinaryScalarOp):
(
z
,)
=
out
(
z
,)
=
out
if
node
.
inputs
[
0
]
.
type
in
float_types
:
if
node
.
inputs
[
0
]
.
type
in
float_types
:
dtype
=
"npy_"
+
node
.
outputs
[
0
]
.
dtype
dtype
=
"npy_"
+
node
.
outputs
[
0
]
.
dtype
return
"""
%(z)
s
=
return
"""
{z}
=
(
%(dtype)
s) upperGamma(
%(k)
s,
%(x)
s);"""
%
locals
(
)
(
{dtype}) upperGamma({k}, {x});"""
.
format
(
**
locals
()
)
raise
NotImplementedError
(
"only floatingpoint is implemented"
)
raise
NotImplementedError
(
"only floatingpoint is implemented"
)
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
...
@@ -1056,8 +1056,8 @@ class GammaL(BinaryScalarOp):
...
@@ -1056,8 +1056,8 @@ class GammaL(BinaryScalarOp):
(
z
,)
=
out
(
z
,)
=
out
if
node
.
inputs
[
0
]
.
type
in
float_types
:
if
node
.
inputs
[
0
]
.
type
in
float_types
:
dtype
=
"npy_"
+
node
.
outputs
[
0
]
.
dtype
dtype
=
"npy_"
+
node
.
outputs
[
0
]
.
dtype
return
"""
%(z)
s
=
return
"""
{z}
=
(
%(dtype)
s) lowerGamma(
%(k)
s,
%(x)
s);"""
%
locals
(
)
(
{dtype}) lowerGamma({k}, {x});"""
.
format
(
**
locals
()
)
raise
NotImplementedError
(
"only floatingpoint is implemented"
)
raise
NotImplementedError
(
"only floatingpoint is implemented"
)
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
...
...
pytensor/scan/op.py
浏览文件 @
06c5acdf
...
@@ -3364,8 +3364,7 @@ def profile_printer(
...
@@ -3364,8 +3364,7 @@ def profile_printer(
total_scan_fct_time
+=
scan_fct_time
total_scan_fct_time
+=
scan_fct_time
total_scan_op_time
+=
scan_op_time
total_scan_op_time
+=
scan_op_time
print
(
print
(
"
%5.1
fs
%5.1
fs
%5.1
fs
%5.1
f
%% %5.1
f
%%
"
" {:5.1f}s {:5.1f}s {:5.1f}s {:5.1f}
%
{:5.1f}
%
"
.
format
(
%
(
v
,
v
,
scan_fct_time
,
scan_fct_time
,
scan_op_time
,
scan_op_time
,
...
@@ -3385,8 +3384,7 @@ def profile_printer(
...
@@ -3385,8 +3384,7 @@ def profile_printer(
print
(
" No scan have its inner profile enabled."
,
file
=
file
)
print
(
" No scan have its inner profile enabled."
,
file
=
file
)
else
:
else
:
print
(
print
(
"total
%5.1
fs
%5.1
fs
%5.1
fs
%5.1
f
%% %5.1
f
%%
"
"total {:5.1f}s {:5.1f}s {:5.1f}s {:5.1f}
%
{:5.1f}
%
"
.
format
(
%
(
total_super_scan_time
,
total_super_scan_time
,
total_scan_fct_time
,
total_scan_fct_time
,
total_scan_op_time
,
total_scan_op_time
,
...
...
pytensor/sparse/basic.py
浏览文件 @
06c5acdf
差异被折叠。
点击展开。
pytensor/sparse/rewriting.py
浏览文件 @
06c5acdf
差异被折叠。
点击展开。
pytensor/tensor/basic.py
浏览文件 @
06c5acdf
差异被折叠。
点击展开。
pytensor/tensor/blas.py
浏览文件 @
06c5acdf
...
@@ -1840,19 +1840,19 @@ class BatchedDot(COp):
...
@@ -1840,19 +1840,19 @@ class BatchedDot(COp):
z_shape
=
", "
.
join
(
z_dims
)
z_shape
=
", "
.
join
(
z_dims
)
z_contiguous
=
contiguous
(
_z
,
z_ndim
)
z_contiguous
=
contiguous
(
_z
,
z_ndim
)
allocate
=
"""
allocate
=
"""
if (NULL ==
%(_z)
s || !(
%(z_shape_correct)
s) || !(
%(z_contiguous)
s
))
if (NULL ==
{_z} || !({z_shape_correct}) || !({z_contiguous}
))
{
{
{
npy_intp dims[
%(z_ndim)
s] = {
%(z_shape)
s
};
npy_intp dims[
{z_ndim}] = {{{z_shape}}
};
Py_XDECREF(
%(_z)
s
);
Py_XDECREF(
{_z}
);
%(_z)
s
= (PyArrayObject*)PyArray_SimpleNew(
{_z}
= (PyArrayObject*)PyArray_SimpleNew(
%(z_ndim)
s, dims, PyArray_TYPE(
%(_x)
s
));
{z_ndim}, dims, PyArray_TYPE({_x}
));
if(!
%(_z)
s)
{
if(!
{_z}) {
{
PyErr_SetString(PyExc_MemoryError,
PyErr_SetString(PyExc_MemoryError,
"failed to alloc BatchedDot output");
"failed to alloc BatchedDot output");
%(fail)
s
{fail}
}
}
}
}
}
}
"""
%
locals
(
)
"""
.
format
(
**
locals
()
)
# code to reallocate inputs contiguously if necessary
# code to reallocate inputs contiguously if necessary
contiguate
=
[]
contiguate
=
[]
...
@@ -1860,76 +1860,75 @@ class BatchedDot(COp):
...
@@ -1860,76 +1860,75 @@ class BatchedDot(COp):
_contiguous
=
contiguous
(
var
,
ndim
)
_contiguous
=
contiguous
(
var
,
ndim
)
contiguate
.
append
(
contiguate
.
append
(
"""
"""
if (!(
%(_contiguous)
s))
{
if (!(
{_contiguous})) {
{
PyArrayObject * _copy = (PyArrayObject *) PyArray_Copy(
%(var)
s
);
PyArrayObject * _copy = (PyArrayObject *) PyArray_Copy(
{var}
);
if (!_copy)
if (!_copy)
%(fail)
s
{fail}
Py_XDECREF(
%(var)
s);
Py_XDECREF({var});
%(var)
s = _copy;
{var} = _copy;
}
}}
"""
"""
.
format
(
**
locals
())
%
locals
()
)
)
contiguate
=
"
\n
"
.
join
(
contiguate
)
contiguate
=
"
\n
"
.
join
(
contiguate
)
return
"""
return
"""
int type_num = PyArray_DESCR(
%(_x)
s
)->type_num;
int type_num = PyArray_DESCR(
{_x}
)->type_num;
int type_size = PyArray_DESCR(
%(_x)
s
)->elsize; // in bytes
int type_size = PyArray_DESCR(
{_x}
)->elsize; // in bytes
if (PyArray_NDIM(
%(_x)
s) != 3)
{
if (PyArray_NDIM(
{_x}) != 3) {
{
PyErr_Format(PyExc_NotImplementedError,
PyErr_Format(PyExc_NotImplementedError,
"rank(x) != 3. rank(x) is
%
%
d.",
"rank(x) != 3. rank(x) is
%
d.",
PyArray_NDIM(
%(_x)
s
));
PyArray_NDIM(
{_x}
));
%(fail)
s
;
{fail}
;
}
}
}
if (PyArray_NDIM(
%(_y)
s) != 3)
{
if (PyArray_NDIM(
{_y}) != 3) {
{
PyErr_Format(PyExc_NotImplementedError,
PyErr_Format(PyExc_NotImplementedError,
"rank(y) != 3. rank(y) is
%
%
d.",
"rank(y) != 3. rank(y) is
%
d.",
PyArray_NDIM(
%(_y)
s
));
PyArray_NDIM(
{_y}
));
%(fail)
s
;
{fail}
;
}
}
}
if (
%(_z)
s && PyArray_NDIM(
%(_z)
s) != 3)
{
if (
{_z} && PyArray_NDIM({_z}) != 3) {
{
PyErr_Format(PyExc_NotImplementedError,
PyErr_Format(PyExc_NotImplementedError,
"rank(z) != 3. rank(z) is
%
%
d.",
"rank(z) != 3. rank(z) is
%
d.",
PyArray_NDIM(
%(_z)
s
));
PyArray_NDIM(
{_z}
));
%(fail)
s
;
{fail}
;
}
}
}
// allocate output
// allocate output
%(allocate)
s
{allocate}
// reallocate any noncontiguous arrays or arrays with invalid strides
// reallocate any noncontiguous arrays or arrays with invalid strides
%(contiguate)
s
{contiguate}
if ((PyArray_DESCR(
%(_x)
s
)->type_num != NPY_DOUBLE)
if ((PyArray_DESCR(
{_x}
)->type_num != NPY_DOUBLE)
&& (PyArray_DESCR(
%(_x)
s
)->type_num != NPY_FLOAT))
&& (PyArray_DESCR(
{_x}
)->type_num != NPY_FLOAT))
{
PyErr_SetString(PyExc_NotImplementedError, "type(x) is not double or float");
%(fail)
s;
}
{
{PyErr_SetString(PyExc_NotImplementedError, "type(x) is not double or float"); {fail};}
}
if ((PyArray_DESCR(
%(_y)
s
)->type_num != NPY_DOUBLE)
if ((PyArray_DESCR(
{_y}
)->type_num != NPY_DOUBLE)
&& (PyArray_DESCR(
%(_y)
s
)->type_num != NPY_FLOAT))
&& (PyArray_DESCR(
{_y}
)->type_num != NPY_FLOAT))
{
PyErr_SetString(PyExc_NotImplementedError, "type(y) is not double or float");
%(fail)
s;
}
{
{PyErr_SetString(PyExc_NotImplementedError, "type(y) is not double or float"); {fail};}
}
if ((PyArray_DESCR(
%(_z)
s
)->type_num != NPY_DOUBLE)
if ((PyArray_DESCR(
{_z}
)->type_num != NPY_DOUBLE)
&& (PyArray_DESCR(
%(_z)
s
)->type_num != NPY_FLOAT))
&& (PyArray_DESCR(
{_z}
)->type_num != NPY_FLOAT))
{
PyErr_SetString(PyExc_NotImplementedError, "type(z) is not double or float");
%(fail)
s;
}
{
{PyErr_SetString(PyExc_NotImplementedError, "type(z) is not double or float"); {fail};}
}
if ((PyArray_DESCR(
%(_x)
s)->type_num != PyArray_DESCR(
%(_y)
s
)->type_num)
if ((PyArray_DESCR(
{_x})->type_num != PyArray_DESCR({_y}
)->type_num)
||(PyArray_DESCR(
%(_x)
s)->type_num != PyArray_DESCR(
%(_z)
s
)->type_num))
||(PyArray_DESCR(
{_x})->type_num != PyArray_DESCR({_z}
)->type_num))
{
PyErr_SetString(PyExc_NotImplementedError, "type(x), type(y), type(z) are not all the same");
%(fail)
s;
}
{
{ PyErr_SetString(PyExc_NotImplementedError, "type(x), type(y), type(z) are not all the same"); {fail}; }
}
switch (type_num)
switch (type_num)
{
{
{
case NPY_FLOAT:
case NPY_FLOAT:
if (batch_gemm<float>(sgemm_, type_size,
%(_x)
s,
%(_y)
s,
%(_z)
s))
{
if (batch_gemm<float>(sgemm_, type_size,
{_x}, {_y}, {_z})) {
{
%(fail)
s
;
{fail}
;
}
}
}
break;
break;
case NPY_DOUBLE:
case NPY_DOUBLE:
if (batch_gemm<double>(dgemm_, type_size,
%(_x)
s,
%(_y)
s,
%(_z)
s))
{
if (batch_gemm<double>(dgemm_, type_size,
{_x}, {_y}, {_z})) {
{
%(fail)
s
;
{fail}
;
}
}
}
break;
break;
}
}
}
"""
%
locals
(
)
"""
.
format
(
**
locals
()
)
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
from
pytensor.tensor.blas_headers
import
blas_header_version
from
pytensor.tensor.blas_headers
import
blas_header_version
...
...
pytensor/tensor/blas_c.py
浏览文件 @
06c5acdf
差异被折叠。
点击展开。
pytensor/tensor/blas_headers.py
浏览文件 @
06c5acdf
...
@@ -1097,7 +1097,7 @@ def ____gemm_code(check_ab, a_init, b_init):
...
@@ -1097,7 +1097,7 @@ def ____gemm_code(check_ab, a_init, b_init):
if (PyArray_NDIM(_y) != 2) goto _dot_execute_fallback;
if (PyArray_NDIM(_y) != 2) goto _dot_execute_fallback;
if (PyArray_NDIM(_z) != 2) goto _dot_execute_fallback;
if (PyArray_NDIM(_z) != 2) goto _dot_execute_fallback;
%(check_ab)
s
{check_ab}
if ((PyArray_DESCR(_x)->type_num != NPY_DOUBLE)
if ((PyArray_DESCR(_x)->type_num != NPY_DOUBLE)
&& (PyArray_DESCR(_x)->type_num != NPY_FLOAT))
&& (PyArray_DESCR(_x)->type_num != NPY_FLOAT))
...
@@ -1117,16 +1117,16 @@ def ____gemm_code(check_ab, a_init, b_init):
...
@@ -1117,16 +1117,16 @@ def ____gemm_code(check_ab, a_init, b_init):
if ((Nx[0] != Nz[0]) || (Nx[1] != Ny[0]) || (Ny[1] != Nz[1]))
if ((Nx[0] != Nz[0]) || (Nx[1] != Ny[0]) || (Ny[1] != Nz[1]))
{
{
{
error_string = "Input dimensions do not agree";
error_string = "Input dimensions do not agree";
goto _dot_execute_fail;
goto _dot_execute_fail;
}
}
}
if ((Sx[0] < 1) || (Sx[1] < 1) || (Sx[0]
%(mod)
s type_size) || (Sx[1]
%(mod)
s
type_size)
if ((Sx[0] < 1) || (Sx[1] < 1) || (Sx[0]
{mod} type_size) || (Sx[1] {mod}
type_size)
|| (Sy[0] < 1) || (Sy[1] < 1) || (Sy[0]
%(mod)
s type_size) || (Sy[1]
%(mod)
s
type_size)
|| (Sy[0] < 1) || (Sy[1] < 1) || (Sy[0]
{mod} type_size) || (Sy[1] {mod}
type_size)
|| (Sz[0] < 1) || (Sz[1] < 1) || (Sz[0]
%(mod)
s type_size) || (Sz[1]
%(mod)
s
type_size))
|| (Sz[0] < 1) || (Sz[1] < 1) || (Sz[0]
{mod} type_size) || (Sz[1] {mod}
type_size))
{
{
{
goto _dot_execute_fallback;
goto _dot_execute_fallback;
}
}
}
/*
/*
encode the stride structure of _x,_y,_z into a single integer
encode the stride structure of _x,_y,_z into a single integer
...
@@ -1146,19 +1146,19 @@ def ____gemm_code(check_ab, a_init, b_init):
...
@@ -1146,19 +1146,19 @@ def ____gemm_code(check_ab, a_init, b_init):
sz_1 = (Nz[1] > 1) ? Sz[1]/type_size : Nz[0];
sz_1 = (Nz[1] > 1) ? Sz[1]/type_size : Nz[0];
switch (type_num)
switch (type_num)
{
{
{
case NPY_FLOAT:
case NPY_FLOAT:
{
{
{
#define REAL float
#define REAL float
float a =
%(a_init)
s
;
float a =
{a_init}
;
float b =
%(b_init)
s
;
float b =
{b_init}
;
float* x = (float*)PyArray_DATA(_x);
float* x = (float*)PyArray_DATA(_x);
float* y = (float*)PyArray_DATA(_y);
float* y = (float*)PyArray_DATA(_y);
float* z = (float*)PyArray_DATA(_z);
float* z = (float*)PyArray_DATA(_z);
switch(unit)
switch(unit)
{
{
{
case 0x000: cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_0); break;
case 0x000: cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_0); break;
case 0x001: cblas_sgemm(CblasRowMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_0); break;
case 0x001: cblas_sgemm(CblasRowMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_0); break;
case 0x010: cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_0); break;
case 0x010: cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_0); break;
...
@@ -1168,21 +1168,21 @@ def ____gemm_code(check_ab, a_init, b_init):
...
@@ -1168,21 +1168,21 @@ def ____gemm_code(check_ab, a_init, b_init):
case 0x110: cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_1); break;
case 0x110: cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_1); break;
case 0x111: cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_1); break;
case 0x111: cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_1); break;
default: goto _dot_execute_fallback;
default: goto _dot_execute_fallback;
};
}
}
;
#undef REAL
#undef REAL
}
}
}
break;
break;
case NPY_DOUBLE:
case NPY_DOUBLE:
{
{
{
#define REAL double
#define REAL double
double a =
%(a_init)
s
;
double a =
{a_init}
;
double b =
%(b_init)
s
;
double b =
{b_init}
;
double* x = (double*)PyArray_DATA(_x);
double* x = (double*)PyArray_DATA(_x);
double* y = (double*)PyArray_DATA(_y);
double* y = (double*)PyArray_DATA(_y);
double* z = (double*)PyArray_DATA(_z);
double* z = (double*)PyArray_DATA(_z);
switch(unit)
switch(unit)
{
{
{
case 0x000: cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_0); break;
case 0x000: cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_0); break;
case 0x001: cblas_dgemm(CblasRowMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_0); break;
case 0x001: cblas_dgemm(CblasRowMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_0); break;
case 0x010: cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_0); break;
case 0x010: cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_0); break;
...
@@ -1192,11 +1192,11 @@ def ____gemm_code(check_ab, a_init, b_init):
...
@@ -1192,11 +1192,11 @@ def ____gemm_code(check_ab, a_init, b_init):
case 0x110: cblas_dgemm(CblasColMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_1); break;
case 0x110: cblas_dgemm(CblasColMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_1); break;
case 0x111: cblas_dgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_1); break;
case 0x111: cblas_dgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_1); break;
default: goto _dot_execute_fallback;
default: goto _dot_execute_fallback;
};
}
}
;
#undef REAL
#undef REAL
}
}
}
break;
break;
}
}
}
return 0; //success!
return 0; //success!
...
@@ -1212,4 +1212,4 @@ def ____gemm_code(check_ab, a_init, b_init):
...
@@ -1212,4 +1212,4 @@ def ____gemm_code(check_ab, a_init, b_init):
return -1;
return -1;
/* v 1 */
/* v 1 */
"""
%
locals
(
)
"""
.
format
(
**
locals
()
)
pytensor/tensor/elemwise.py
浏览文件 @
06c5acdf
...
@@ -929,12 +929,12 @@ class Elemwise(OpenMPOp):
...
@@ -929,12 +929,12 @@ class Elemwise(OpenMPOp):
# decrease the reference of whatever the output contained
# decrease the reference of whatever the output contained
# prior to this
# prior to this
alloc
+=
"""
alloc
+=
"""
if (
%(oname)
s)
{
if (
{oname}) {
{
Py_XDECREF(
%(oname)
s
);
Py_XDECREF(
{oname}
);
}
}
}
%(oname)
s =
%(iname)
s
;
{oname} = {iname}
;
Py_XINCREF(
%(oname)
s
);
Py_XINCREF(
{oname}
);
"""
%
locals
(
)
"""
.
format
(
**
locals
()
)
# We alias the scalar variables
# We alias the scalar variables
defines
+=
f
"#define {oname}_i {iname}_i
\n
"
defines
+=
f
"#define {oname}_i {iname}_i
\n
"
undefs
+=
f
"#undef {oname}_i
\n
"
undefs
+=
f
"#undef {oname}_i
\n
"
...
@@ -958,12 +958,12 @@ class Elemwise(OpenMPOp):
...
@@ -958,12 +958,12 @@ class Elemwise(OpenMPOp):
dict
(
sub
,
fail
=
fail
),
dict
(
sub
,
fail
=
fail
),
)
)
code
=
"""
code
=
"""
{
{
{
%(defines)
s
{defines}
%(task_code)
s
{task_code}
%(undefs)
s
{undefs}
}
}
}
"""
%
locals
(
)
"""
.
format
(
**
locals
()
)
loop_orders
=
orders
+
[
list
(
range
(
nnested
))]
*
len
(
real_onames
)
loop_orders
=
orders
+
[
list
(
range
(
nnested
))]
*
len
(
real_onames
)
dtypes
=
idtypes
+
list
(
real_odtypes
)
dtypes
=
idtypes
+
list
(
real_odtypes
)
...
@@ -995,27 +995,27 @@ class Elemwise(OpenMPOp):
...
@@ -995,27 +995,27 @@ class Elemwise(OpenMPOp):
if
index
!=
"x"
:
if
index
!=
"x"
:
preloops
.
setdefault
(
j
,
""
)
preloops
.
setdefault
(
j
,
""
)
preloops
[
j
]
+=
(
preloops
[
j
]
+=
(
"
%
%
(lv
%(i)
s)s_iter = (
%(dtype)
s
*)"
"
%
(lv{i})s_iter = ({dtype}
*)"
"(PyArray_DATA(
%
%
(lv
%(i)
s)s));
\n
"
%
locals
(
)
"(PyArray_DATA(
%
(lv{i})s));
\n
"
.
format
(
**
locals
()
)
)
%
sub
)
%
sub
break
break
else
:
# all broadcastable
else
:
# all broadcastable
preloops
.
setdefault
(
0
,
""
)
preloops
.
setdefault
(
0
,
""
)
preloops
[
0
]
+=
(
preloops
[
0
]
+=
(
"
%
%
(lv
%(i)
s)s_iter = (
%(dtype)
s
*)"
"
%
(lv{i})s_iter = ({dtype}
*)"
"(PyArray_DATA(
%
%
(lv
%(i)
s)s));
\n
"
%
locals
(
)
"(PyArray_DATA(
%
(lv{i})s));
\n
"
.
format
(
**
locals
()
)
)
%
sub
)
%
sub
init_array
=
preloops
.
get
(
0
,
" "
)
init_array
=
preloops
.
get
(
0
,
" "
)
loop
=
"""
loop
=
"""
{
{
{
%(defines)
s
{defines}
%(init_array)
s
{init_array}
%(task_decl)
s
{task_decl}
%(task_code)
s
{task_code}
%(undefs)
s
{undefs}
}
}
}
"""
%
locals
(
)
"""
.
format
(
**
locals
()
)
else
:
else
:
loop
=
cgen
.
make_loop
(
loop
=
cgen
.
make_loop
(
loop_orders
=
loop_orders
,
loop_orders
=
loop_orders
,
...
@@ -1076,24 +1076,24 @@ class Elemwise(OpenMPOp):
...
@@ -1076,24 +1076,24 @@ class Elemwise(OpenMPOp):
for
x
,
var
in
zip
(
inames
+
onames
,
inputs
+
node
.
outputs
):
for
x
,
var
in
zip
(
inames
+
onames
,
inputs
+
node
.
outputs
):
if
not
all
(
s
==
1
for
s
in
var
.
type
.
shape
):
if
not
all
(
s
==
1
for
s
in
var
.
type
.
shape
):
contig
+=
"""
contig
+=
"""
dtype_
%(x)
s *
%(x)
s_ptr = (dtype_
%(x)
s*) PyArray_DATA(
%(x)
s
);
dtype_
{x} * {x}_ptr = (dtype_{x}*) PyArray_DATA({x}
);
"""
%
locals
(
)
"""
.
format
(
**
locals
()
)
index
+=
"""
index
+=
"""
dtype_
%(x)
s&
%(x)
s_i =
%(x)
s
_ptr[i];
dtype_
{x}& {x}_i = {x}
_ptr[i];
"""
%
locals
(
)
"""
.
format
(
**
locals
()
)
else
:
else
:
contig
+=
"""
contig
+=
"""
dtype_
%(x)
s&
%(x)
s_i = ((dtype_
%(x)
s*) PyArray_DATA(
%(x)
s
))[0];
dtype_
{x}& {x}_i = ((dtype_{x}*) PyArray_DATA({x}
))[0];
"""
%
locals
(
)
"""
.
format
(
**
locals
()
)
if
self
.
openmp
:
if
self
.
openmp
:
contig
+=
f
"""#pragma omp parallel for if(n>={int(config.openmp_elemwise_minsize)})
contig
+=
f
"""#pragma omp parallel for if(n>={int(config.openmp_elemwise_minsize)})
"""
"""
contig
+=
"""
contig
+=
"""
for(int i=0; i<n; i++){
for(int i=0; i<n; i++){
{
%(index)
s
{index}
%(task_code)
s
;
{task_code}
;
}
}
}
"""
%
locals
(
)
"""
.
format
(
**
locals
()
)
if
contig
is
not
None
:
if
contig
is
not
None
:
z
=
list
(
zip
(
inames
+
onames
,
inputs
+
node
.
outputs
))
z
=
list
(
zip
(
inames
+
onames
,
inputs
+
node
.
outputs
))
all_broadcastable
=
all
(
s
==
1
for
s
in
var
.
type
.
shape
)
all_broadcastable
=
all
(
s
==
1
for
s
in
var
.
type
.
shape
)
...
@@ -1112,12 +1112,12 @@ class Elemwise(OpenMPOp):
...
@@ -1112,12 +1112,12 @@ class Elemwise(OpenMPOp):
]
]
)
)
loop
=
"""
loop
=
"""
if((
%(cond1)
s) || (
%(cond2)
s))
{
if((
{cond1}) || ({cond2})){
{
%(contig)
s
{contig}
}
else
{
}
}else{
{
%(loop)
s
{loop}
}
}
}
"""
%
locals
(
)
"""
.
format
(
**
locals
()
)
return
decl
,
checks
,
alloc
,
loop
,
""
return
decl
,
checks
,
alloc
,
loop
,
""
def
c_code
(
self
,
node
,
nodename
,
inames
,
onames
,
sub
):
def
c_code
(
self
,
node
,
nodename
,
inames
,
onames
,
sub
):
...
...
pytensor/tensor/elemwise_cgen.py
浏览文件 @
06c5acdf
...
@@ -176,34 +176,34 @@ def make_alloc(loop_orders, dtype, sub, fortran="0"):
...
@@ -176,34 +176,34 @@ def make_alloc(loop_orders, dtype, sub, fortran="0"):
# contiguous dimensions, or the dimension with the smallest
# contiguous dimensions, or the dimension with the smallest
# stride. Right now, it is allocated to be C_CONTIGUOUS.
# stride. Right now, it is allocated to be C_CONTIGUOUS.
return
"""
return
"""
{
{
{
npy_intp dims[
%(nd)
s
];
npy_intp dims[
{nd}
];
//npy_intp* dims = (npy_intp*)malloc(
%(nd)
s
* sizeof(npy_intp));
//npy_intp* dims = (npy_intp*)malloc(
{nd}
* sizeof(npy_intp));
%(init_dims)
s
{init_dims}
if (!
%(olv)
s)
{
if (!
{olv}) {
{
%(olv)
s = (PyArrayObject*)PyArray_EMPTY(
%(nd)
s
, dims,
{olv} = (PyArrayObject*)PyArray_EMPTY({nd}
, dims,
%(type)
s
,
{type}
,
%(fortran)
s
);
{fortran}
);
}
}
}
else {
else {
{
PyArray_Dims new_dims;
PyArray_Dims new_dims;
new_dims.len =
%(nd)
s
;
new_dims.len =
{nd}
;
new_dims.ptr = dims;
new_dims.ptr = dims;
PyObject* success = PyArray_Resize(
%(olv)
s
, &new_dims, 0, NPY_CORDER);
PyObject* success = PyArray_Resize(
{olv}
, &new_dims, 0, NPY_CORDER);
if (!success) {
if (!success) {
{
// If we can't resize the ndarray we have we can allocate a new one.
// If we can't resize the ndarray we have we can allocate a new one.
PyErr_Clear();
PyErr_Clear();
Py_XDECREF(
%(olv)
s
);
Py_XDECREF(
{olv}
);
%(olv)
s = (PyArrayObject*)PyArray_EMPTY(
%(nd)
s, dims,
%(type)
s
, 0);
{olv} = (PyArrayObject*)PyArray_EMPTY({nd}, dims, {type}
, 0);
}
else
{
}
} else {
{
Py_DECREF(success);
Py_DECREF(success);
}
}
}
}
}
}
if (!
%(olv)
s)
{
if (!
{olv}) {
{
%(fail)
s
{fail}
}
}
}
}
}
}
"""
%
dict
(
locals
(),
**
sub
)
"""
.
format
(
**
dict
(
locals
(),
**
sub
)
)
def
make_loop
(
loop_orders
,
dtypes
,
loop_tasks
,
sub
,
openmp
=
None
):
def
make_loop
(
loop_orders
,
dtypes
,
loop_tasks
,
sub
,
openmp
=
None
):
...
...
pytensor/tensor/extra_ops.py
浏览文件 @
06c5acdf
...
@@ -69,24 +69,24 @@ class CpuContiguous(COp):
...
@@ -69,24 +69,24 @@ class CpuContiguous(COp):
(
x
,)
=
inames
(
x
,)
=
inames
(
y
,)
=
onames
(
y
,)
=
onames
code
=
"""
code
=
"""
if (!PyArray_CHKFLAGS(
%(x)
s, NPY_ARRAY_C_CONTIGUOUS))
{
if (!PyArray_CHKFLAGS(
{x}, NPY_ARRAY_C_CONTIGUOUS)){
{
// check to see if output is contiguous first
// check to see if output is contiguous first
if (
%(y)
s
!= NULL &&
if (
{y}
!= NULL &&
PyArray_CompareLists(PyArray_DIMS(
%(y)
s), PyArray_DIMS(
%(x)
s), PyArray_NDIM(
%(x)
s
)) &&
PyArray_CompareLists(PyArray_DIMS(
{y}), PyArray_DIMS({x}), PyArray_NDIM({x}
)) &&
PyArray_CHKFLAGS(
%(y)
s, NPY_ARRAY_C_CONTIGUOUS))
{
PyArray_CHKFLAGS(
{y}, NPY_ARRAY_C_CONTIGUOUS)){
{
PyArray_CopyInto(
%(y)
s,
%(x)
s
);
PyArray_CopyInto(
{y}, {x}
);
}
}
}
else{
else{
{
Py_XDECREF(
%(y)
s
);
Py_XDECREF(
{y}
);
%(y)
s = PyArray_GETCONTIGUOUS(
%(x)
s
);
{y} = PyArray_GETCONTIGUOUS({x}
);
}
}
}
}
}
}
else{
else{
{
Py_XINCREF(
%(x)
s
);
Py_XINCREF(
{x}
);
Py_XDECREF(
%(y)
s
);
Py_XDECREF(
{y}
);
%(y)
s =
%(x)
s
;
{y} = {x}
;
}
}
}
"""
%
locals
(
)
"""
.
format
(
**
locals
()
)
return
code
return
code
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
...
@@ -162,12 +162,12 @@ class SearchsortedOp(COp):
...
@@ -162,12 +162,12 @@ class SearchsortedOp(COp):
side
=
sub
[
"params"
]
side
=
sub
[
"params"
]
fail
=
sub
[
"fail"
]
fail
=
sub
[
"fail"
]
return
"""
return
"""
PyObject* tmp_
%(name)
s
= PyUnicode_FromString("right");
PyObject* tmp_
{name}
= PyUnicode_FromString("right");
if (tmp_
%(name)
s
== NULL)
if (tmp_
{name}
== NULL)
%(fail)
s
;
{fail}
;
right_
%(name)
s = PyUnicode_Compare(
%(side)
s, tmp_
%(name)
s
);
right_
{name} = PyUnicode_Compare({side}, tmp_{name}
);
Py_DECREF(tmp_
%(name)
s
);
Py_DECREF(tmp_
{name}
);
"""
%
locals
(
)
"""
.
format
(
**
locals
()
)
def
c_code
(
self
,
node
,
name
,
inames
,
onames
,
sub
):
def
c_code
(
self
,
node
,
name
,
inames
,
onames
,
sub
):
sorter
=
None
sorter
=
None
...
@@ -181,17 +181,17 @@ class SearchsortedOp(COp):
...
@@ -181,17 +181,17 @@ class SearchsortedOp(COp):
fail
=
sub
[
"fail"
]
fail
=
sub
[
"fail"
]
return
"""
return
"""
Py_XDECREF(
%(z)
s
);
Py_XDECREF(
{z}
);
%(z)
s = (PyArrayObject*) PyArray_SearchSorted(
%(x)
s, (PyObject*)
%(v)
s
,
{z} = (PyArrayObject*) PyArray_SearchSorted({x}, (PyObject*) {v}
,
right_
%(name)
s ? NPY_SEARCHLEFT : NPY_SEARCHRIGHT, (PyObject*)
%(sorter)
s
);
right_
{name} ? NPY_SEARCHLEFT : NPY_SEARCHRIGHT, (PyObject*) {sorter}
);
if (!
%(z)
s
)
if (!
{z}
)
%(fail)
s
;
{fail}
;
if (PyArray_TYPE(
%(z)
s) != NPY_INT64)
{
if (PyArray_TYPE(
{z}) != NPY_INT64){
{
PyObject * tmp = PyArray_Cast(
%(z)
s
, NPY_INT64);
PyObject * tmp = PyArray_Cast(
{z}
, NPY_INT64);
Py_XDECREF(
%(z)
s
);
Py_XDECREF(
{z}
);
%(z)
s
= (PyArrayObject*) tmp;
{z}
= (PyArrayObject*) tmp;
}
}
}
"""
%
locals
(
)
"""
.
format
(
**
locals
()
)
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
return
(
2
,)
return
(
2
,)
...
@@ -351,43 +351,43 @@ class CumOp(COp):
...
@@ -351,43 +351,43 @@ class CumOp(COp):
params
=
sub
[
"params"
]
params
=
sub
[
"params"
]
code
=
"""
code
=
"""
int axis =
%(params)
s
->c_axis;
int axis =
{params}
->c_axis;
if (axis == 0 && PyArray_NDIM(
%(x)
s
) == 1)
if (axis == 0 && PyArray_NDIM(
{x}
) == 1)
axis = NPY_MAXDIMS;
axis = NPY_MAXDIMS;
npy_intp shape[1] = {
PyArray_SIZE(
%(x)
s)
};
npy_intp shape[1] = {
{ PyArray_SIZE({x}) }
};
if(axis == NPY_MAXDIMS && !(
%(z)
s && PyArray_DIMS(
%(z)
s
)[0] == shape[0]))
if(axis == NPY_MAXDIMS && !(
{z} && PyArray_DIMS({z}
)[0] == shape[0]))
{
{
{
Py_XDECREF(
%(z)
s
);
Py_XDECREF(
{z}
);
%(z)
s = (PyArrayObject*) PyArray_SimpleNew(1, shape, PyArray_TYPE((PyArrayObject*) py_
%(x)
s
));
{z} = (PyArrayObject*) PyArray_SimpleNew(1, shape, PyArray_TYPE((PyArrayObject*) py_{x}
));
}
}
}
else if(axis != NPY_MAXDIMS && !(
%(z)
s && PyArray_CompareLists(PyArray_DIMS(
%(z)
s), PyArray_DIMS(
%(x)
s), PyArray_NDIM(
%(x)
s
))))
else if(axis != NPY_MAXDIMS && !(
{z} && PyArray_CompareLists(PyArray_DIMS({z}), PyArray_DIMS({x}), PyArray_NDIM({x}
))))
{
{
{
Py_XDECREF(
%(z)
s
);
Py_XDECREF(
{z}
);
%(z)
s = (PyArrayObject*) PyArray_SimpleNew(PyArray_NDIM(
%(x)
s), PyArray_DIMS(
%(x)
s), PyArray_TYPE(
%(x)
s
));
{z} = (PyArrayObject*) PyArray_SimpleNew(PyArray_NDIM({x}), PyArray_DIMS({x}), PyArray_TYPE({x}
));
}
}
}
if (!
%(z)
s
)
if (!
{z}
)
%(fail)
s
;
{fail}
;
{
{
{
PyObject * t = NULL;
PyObject * t = NULL;
if(
%(params)
s
->mode == MODE_ADD)
if(
{params}
->mode == MODE_ADD)
t = PyArray_CumSum(
t = PyArray_CumSum(
%(x)
s
, axis,
{x}
, axis,
PyArray_TYPE(
%(x)
s),
%(z)
s
);
PyArray_TYPE(
{x}), {z}
);
else if(
%(params)
s
->mode == MODE_MUL)
else if(
{params}
->mode == MODE_MUL)
t = PyArray_CumProd(
t = PyArray_CumProd(
%(x)
s
, axis,
{x}
, axis,
PyArray_TYPE(
%(x)
s),
%(z)
s
);
PyArray_TYPE(
{x}), {z}
);
if (!t){
if (!t){
{
%(fail)
s
;
{fail}
;
}
}
}
// Because PyArray_CumSum/CumProd returns a newly created reference on t.
// Because PyArray_CumSum/CumProd returns a newly created reference on t.
Py_XDECREF(t);
Py_XDECREF(t);
}
}
}
"""
%
locals
(
)
"""
.
format
(
**
locals
()
)
return
code
return
code
...
...
pytensor/tensor/math.py
浏览文件 @
06c5acdf
...
@@ -420,13 +420,13 @@ class Argmax(COp):
...
@@ -420,13 +420,13 @@ class Argmax(COp):
raise
NotImplementedError
()
raise
NotImplementedError
()
# params is only used here for now
# params is only used here for now
axis_code
=
"""
axis_code
=
"""
axis =
%(params)
s
->c_axis;
axis =
{params}
->c_axis;
if(axis > PyArray_NDIM(
%(x)
s)-1 || axis < -PyArray_NDIM(
%(x)
s))
{
if(axis > PyArray_NDIM(
{x})-1 || axis < -PyArray_NDIM({x})){
{
PyErr_SetString(PyExc_ValueError,
PyErr_SetString(PyExc_ValueError,
"Argmax, bad axis argument");
"Argmax, bad axis argument");
%(fail)
s
{fail}
}
}
}
"""
%
locals
(
)
"""
.
format
(
**
locals
()
)
ret
=
"""
ret
=
"""
int axis;
int axis;
...
...
pytensor/tensor/subtensor.py
浏览文件 @
06c5acdf
差异被折叠。
点击展开。
pytensor/tensor/type.py
浏览文件 @
06c5acdf
...
@@ -476,47 +476,47 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
...
@@ -476,47 +476,47 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
def
c_declare
(
self
,
name
,
sub
,
check_input
=
True
):
def
c_declare
(
self
,
name
,
sub
,
check_input
=
True
):
if
check_input
:
if
check_input
:
check
=
"""
check
=
"""
typedef
%(dtype)
s dtype_
%(name)
s
;
typedef
{dtype} dtype_{name}
;
"""
%
dict
(
sub
,
name
=
name
,
dtype
=
self
.
dtype_specs
()[
1
]
)
"""
.
format
(
**
dict
(
sub
,
name
=
name
,
dtype
=
self
.
dtype_specs
()[
1
])
)
else
:
else
:
check
=
""
check
=
""
declaration
=
"""
declaration
=
"""
PyArrayObject*
%(name)
s
;
PyArrayObject*
{name}
;
"""
%
dict
(
sub
,
name
=
name
,
dtype
=
self
.
dtype_specs
()[
1
]
)
"""
.
format
(
**
dict
(
sub
,
name
=
name
,
dtype
=
self
.
dtype_specs
()[
1
])
)
return
declaration
+
check
return
declaration
+
check
def
c_init
(
self
,
name
,
sub
):
def
c_init
(
self
,
name
,
sub
):
return
"""
return
"""
%(name)
s
= NULL;
{name}
= NULL;
"""
%
dict
(
sub
,
name
=
name
,
type_num
=
self
.
dtype_specs
()[
2
]
)
"""
.
format
(
**
dict
(
sub
,
name
=
name
,
type_num
=
self
.
dtype_specs
()[
2
])
)
def
c_extract
(
self
,
name
,
sub
,
check_input
=
True
,
**
kwargs
):
def
c_extract
(
self
,
name
,
sub
,
check_input
=
True
,
**
kwargs
):
if
check_input
:
if
check_input
:
check
=
"""
check
=
"""
%(name)
s
= NULL;
{name}
= NULL;
if (py_
%(name)
s == Py_None)
{
if (py_
{name} == Py_None) {
{
// We can either fail here or set
%(name)
s
to NULL and rely on Ops
// We can either fail here or set
{name}
to NULL and rely on Ops
// using tensors to handle the NULL case, but if they fail to do so
// using tensors to handle the NULL case, but if they fail to do so
// they'll end up with nasty segfaults, so this is public service.
// they'll end up with nasty segfaults, so this is public service.
PyErr_SetString(PyExc_ValueError, "expected an ndarray, not None");
PyErr_SetString(PyExc_ValueError, "expected an ndarray, not None");
%(fail)
s
{fail}
}
}
}
if (!PyArray_Check(py_
%(name)
s))
{
if (!PyArray_Check(py_
{name})) {
{
PyErr_SetString(PyExc_ValueError, "expected an ndarray");
PyErr_SetString(PyExc_ValueError, "expected an ndarray");
%(fail)
s
{fail}
}
}
}
// We expect
%(type_num)
s
// We expect
{type_num}
if (!PyArray_ISALIGNED((PyArrayObject*) py_
%(name)
s))
{
if (!PyArray_ISALIGNED((PyArrayObject*) py_
{name})) {
{
PyArrayObject * tmp = (PyArrayObject*) py_
%(name)
s
;
PyArrayObject * tmp = (PyArrayObject*) py_
{name}
;
PyErr_Format(PyExc_NotImplementedError,
PyErr_Format(PyExc_NotImplementedError,
"expected an aligned array of type
%
%
ld "
"expected an aligned array of type
%
ld "
"(
%(type_num)
s), got non-aligned array of type
%
%
ld"
"(
{type_num}), got non-aligned array of type
%
ld"
" with
%
%
ld dimensions, with 3 last dims "
" with
%
ld dimensions, with 3 last dims "
"
%
%
ld,
%%
ld,
%
%
ld"
"
%
ld,
%
ld,
%
ld"
" and 3 last strides
%
%
ld
%%
ld,
%
%
ld.",
" and 3 last strides
%
ld
%
ld,
%
ld.",
(long int)
%(type_num)
s
,
(long int)
{type_num}
,
(long int) PyArray_TYPE((PyArrayObject*) py_
%(name)
s
),
(long int) PyArray_TYPE((PyArrayObject*) py_
{name}
),
(long int) PyArray_NDIM(tmp),
(long int) PyArray_NDIM(tmp),
(long int) (PyArray_NDIM(tmp) >= 3 ?
(long int) (PyArray_NDIM(tmp) >= 3 ?
PyArray_DIMS(tmp)[PyArray_NDIM(tmp)-3] : -1),
PyArray_DIMS(tmp)[PyArray_NDIM(tmp)-3] : -1),
...
@@ -531,74 +531,73 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
...
@@ -531,74 +531,73 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
(long int) (PyArray_NDIM(tmp) >= 1 ?
(long int) (PyArray_NDIM(tmp) >= 1 ?
PyArray_STRIDES(tmp)[PyArray_NDIM(tmp)-1] : -1)
PyArray_STRIDES(tmp)[PyArray_NDIM(tmp)-1] : -1)
);
);
%(fail)
s
{fail}
}
}
}
// This is a TypeError to be consistent with DEBUG_MODE
// This is a TypeError to be consistent with DEBUG_MODE
// Note: DEBUG_MODE also tells the name of the container
// Note: DEBUG_MODE also tells the name of the container
if (PyArray_TYPE((PyArrayObject*) py_
%(name)
s) !=
%(type_num)
s)
{
if (PyArray_TYPE((PyArrayObject*) py_
{name}) != {type_num}) {
{
PyErr_Format(PyExc_TypeError,
PyErr_Format(PyExc_TypeError,
"expected type_num
%
%
d (
%(type_num)
s) got
%
%
d",
"expected type_num
%
d ({type_num}) got
%
d",
%(type_num)
s, PyArray_TYPE((PyArrayObject*) py_
%(name)
s
));
{type_num}, PyArray_TYPE((PyArrayObject*) py_{name}
));
%(fail)
s
{fail}
}
}
}
"""
%
dict
(
sub
,
name
=
name
,
type_num
=
self
.
dtype_specs
()[
2
]
)
"""
.
format
(
**
dict
(
sub
,
name
=
name
,
type_num
=
self
.
dtype_specs
()[
2
])
)
else
:
else
:
check
=
""
check
=
""
return
(
return
(
check
check
+
"""
+
"""
%(name)
s = (PyArrayObject*)(py_
%(name)
s);
{name} = (PyArrayObject*)(py_{name});
Py_XINCREF(
%(name)
s);
Py_XINCREF({name});
"""
"""
.
format
(
**
dict
(
sub
,
name
=
name
,
type_num
=
self
.
dtype_specs
()[
2
]))
%
dict
(
sub
,
name
=
name
,
type_num
=
self
.
dtype_specs
()[
2
])
)
)
def
c_cleanup
(
self
,
name
,
sub
):
def
c_cleanup
(
self
,
name
,
sub
):
return
"""
return
"""
if (
%(name)
s)
{
if (
{name}) {
{
Py_XDECREF(
%(name)
s
);
Py_XDECREF(
{name}
);
}
}
}
"""
%
locals
(
)
"""
.
format
(
**
locals
()
)
def
c_sync
(
self
,
name
,
sub
):
def
c_sync
(
self
,
name
,
sub
):
fail
=
sub
[
"fail"
]
fail
=
sub
[
"fail"
]
type_num
=
self
.
dtype_specs
()[
2
]
type_num
=
self
.
dtype_specs
()[
2
]
return
"""
return
"""
{
Py_XDECREF(py_
%(name)
s);
}
{
{Py_XDECREF(py_{name});}
}
if (!
%(name)
s)
{
if (!
{name}) {
{
Py_INCREF(Py_None);
Py_INCREF(Py_None);
py_
%(name)
s
= Py_None;
py_
{name}
= Py_None;
}
}
}
else if ((void*)py_
%(name)
s != (void*)
%(name)
s)
{
else if ((void*)py_
{name} != (void*){name}) {
{
py_
%(name)
s = (PyObject*)
%(name)
s
;
py_
{name} = (PyObject*){name}
;
}
}
}
{
Py_XINCREF(py_
%(name)
s);
}
{
{Py_XINCREF(py_{name});}
}
if (
%(name)
s && !PyArray_ISALIGNED((PyArrayObject*) py_
%(name)
s))
{
if (
{name} && !PyArray_ISALIGNED((PyArrayObject*) py_{name})) {
{
PyErr_Format(PyExc_NotImplementedError,
PyErr_Format(PyExc_NotImplementedError,
"c_sync: expected an aligned array, got non-aligned array of type
%
%
ld"
"c_sync: expected an aligned array, got non-aligned array of type
%
ld"
" with
%
%
ld dimensions, with 3 last dims "
" with
%
ld dimensions, with 3 last dims "
"
%
%
ld,
%%
ld,
%
%
ld"
"
%
ld,
%
ld,
%
ld"
" and 3 last strides
%
%
ld
%%
ld,
%
%
ld.",
" and 3 last strides
%
ld
%
ld,
%
ld.",
(long int) PyArray_TYPE((PyArrayObject*) py_
%(name)
s
),
(long int) PyArray_TYPE((PyArrayObject*) py_
{name}
),
(long int) PyArray_NDIM(
%(name)
s
),
(long int) PyArray_NDIM(
{name}
),
(long int) (PyArray_NDIM(
%(name)
s
) >= 3 ?
(long int) (PyArray_NDIM(
{name}
) >= 3 ?
PyArray_DIMS(
%(name)
s)[PyArray_NDIM(
%(name)
s
)-3] : -1),
PyArray_DIMS(
{name})[PyArray_NDIM({name}
)-3] : -1),
(long int) (PyArray_NDIM(
%(name)
s
) >= 2 ?
(long int) (PyArray_NDIM(
{name}
) >= 2 ?
PyArray_DIMS(
%(name)
s)[PyArray_NDIM(
%(name)
s
)-2] : -1),
PyArray_DIMS(
{name})[PyArray_NDIM({name}
)-2] : -1),
(long int) (PyArray_NDIM(
%(name)
s
) >= 1 ?
(long int) (PyArray_NDIM(
{name}
) >= 1 ?
PyArray_DIMS(
%(name)
s)[PyArray_NDIM(
%(name)
s
)-1] : -1),
PyArray_DIMS(
{name})[PyArray_NDIM({name}
)-1] : -1),
(long int) (PyArray_NDIM(
%(name)
s
) >= 3 ?
(long int) (PyArray_NDIM(
{name}
) >= 3 ?
PyArray_STRIDES(
%(name)
s)[PyArray_NDIM(
%(name)
s
)-3] : -1),
PyArray_STRIDES(
{name})[PyArray_NDIM({name}
)-3] : -1),
(long int) (PyArray_NDIM(
%(name)
s
) >= 2 ?
(long int) (PyArray_NDIM(
{name}
) >= 2 ?
PyArray_STRIDES(
%(name)
s)[PyArray_NDIM(
%(name)
s
)-2] : -1),
PyArray_STRIDES(
{name})[PyArray_NDIM({name}
)-2] : -1),
(long int) (PyArray_NDIM(
%(name)
s
) >= 1 ?
(long int) (PyArray_NDIM(
{name}
) >= 1 ?
PyArray_STRIDES(
%(name)
s)[PyArray_NDIM(
%(name)
s
)-1] : -1)
PyArray_STRIDES(
{name})[PyArray_NDIM({name}
)-1] : -1)
);
);
%(fail)
s
{fail}
}
}
}
"""
%
locals
(
)
"""
.
format
(
**
locals
()
)
def
c_headers
(
self
,
**
kwargs
):
def
c_headers
(
self
,
**
kwargs
):
return
ps
.
get_scalar_type
(
self
.
dtype
)
.
c_headers
(
**
kwargs
)
return
ps
.
get_scalar_type
(
self
.
dtype
)
.
c_headers
(
**
kwargs
)
...
...
pytensor/typed_list/basic.py
浏览文件 @
06c5acdf
...
@@ -103,12 +103,12 @@ class GetItem(COp):
...
@@ -103,12 +103,12 @@ class GetItem(COp):
output_name
=
out
[
0
]
output_name
=
out
[
0
]
fail
=
sub
[
"fail"
]
fail
=
sub
[
"fail"
]
return
"""
return
"""
%(output_name)
s = (typeof
%(output_name)
s) PyList_GetItem( (PyObject*)
%(x_name)
s, *((npy_int64 *) PyArray_DATA(
%(index)
s
)));
{output_name} = (typeof {output_name}) PyList_GetItem( (PyObject*) {x_name}, *((npy_int64 *) PyArray_DATA({index}
)));
if(
%(output_name)
s == NULL)
{
if(
{output_name} == NULL){
{
%(fail)
s
{fail}
}
}
}
Py_INCREF(
%(output_name)
s
);
Py_INCREF(
{output_name}
);
"""
%
locals
(
)
"""
.
format
(
**
locals
()
)
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
return
(
1
,)
return
(
1
,)
...
@@ -170,8 +170,8 @@ class Append(COp):
...
@@ -170,8 +170,8 @@ class Append(COp):
fail
=
sub
[
"fail"
]
fail
=
sub
[
"fail"
]
if
not
self
.
inplace
:
if
not
self
.
inplace
:
init
=
"""
init
=
"""
%(output_name)
s = (PyListObject*) PyList_GetSlice((PyObject*)
%(x_name)
s, 0, PyList_GET_SIZE((PyObject*)
%(x_name)
s
)) ;
{output_name} = (PyListObject*) PyList_GetSlice((PyObject*) {x_name}, 0, PyList_GET_SIZE((PyObject*) {x_name}
)) ;
"""
%
locals
(
)
"""
.
format
(
**
locals
()
)
else
:
else
:
init
=
f
"""
init
=
f
"""
{output_name} = {x_name};
{output_name} = {x_name};
...
@@ -179,15 +179,14 @@ class Append(COp):
...
@@ -179,15 +179,14 @@ class Append(COp):
return
(
return
(
init
init
+
"""
+
"""
if(
%(output_name)
s==NULL){
if({output_name}==NULL){{
%(fail)
s
{fail}
};
}};
if(PyList_Append( (PyObject*)
%(output_name)
s,(PyObject*)
%(toAppend)
s)){
if(PyList_Append( (PyObject*) {output_name},(PyObject*) {toAppend})){{
%(fail)
s
{fail}
};
}};
Py_INCREF(
%(output_name)
s);
Py_INCREF({output_name});
"""
"""
.
format
(
**
locals
())
%
locals
()
)
)
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
...
@@ -252,8 +251,8 @@ class Extend(COp):
...
@@ -252,8 +251,8 @@ class Extend(COp):
fail
=
sub
[
"fail"
]
fail
=
sub
[
"fail"
]
if
not
self
.
inplace
:
if
not
self
.
inplace
:
init
=
"""
init
=
"""
%(output_name)
s = (PyListObject*) PyList_GetSlice((PyObject*)
%(x_name)
s, 0, PyList_GET_SIZE((PyObject*)
%(x_name)
s
)) ;
{output_name} = (PyListObject*) PyList_GetSlice((PyObject*) {x_name}, 0, PyList_GET_SIZE((PyObject*) {x_name}
)) ;
"""
%
locals
(
)
"""
.
format
(
**
locals
()
)
else
:
else
:
init
=
f
"""
init
=
f
"""
{output_name} = {x_name};
{output_name} = {x_name};
...
@@ -262,18 +261,17 @@ class Extend(COp):
...
@@ -262,18 +261,17 @@ class Extend(COp):
init
init
+
"""
+
"""
int i =0;
int i =0;
int length = PyList_GET_SIZE((PyObject*)
%(toAppend)
s);
int length = PyList_GET_SIZE((PyObject*) {toAppend});
if(
%(output_name)
s==NULL){
if({output_name}==NULL){{
%(fail)
s
{fail}
};
}};
for(i; i < length; i++){
for(i; i < length; i++){{
if(PyList_Append( (PyObject*)
%(output_name)
s,(PyObject*) PyList_GetItem((PyObject*)
%(toAppend)
s,i))==-1){
if(PyList_Append( (PyObject*) {output_name},(PyObject*) PyList_GetItem((PyObject*) {toAppend},i))==-1){{
%(fail)
s
{fail}
};
}};
}
}}
Py_INCREF(
%(output_name)
s);
Py_INCREF({output_name});
"""
"""
.
format
(
**
locals
())
%
locals
()
)
)
def
c_code_cache_version_
(
self
):
def
c_code_cache_version_
(
self
):
...
@@ -341,8 +339,8 @@ class Insert(COp):
...
@@ -341,8 +339,8 @@ class Insert(COp):
fail
=
sub
[
"fail"
]
fail
=
sub
[
"fail"
]
if
not
self
.
inplace
:
if
not
self
.
inplace
:
init
=
"""
init
=
"""
%(output_name)
s = (PyListObject*) PyList_GetSlice((PyObject*)
%(x_name)
s, 0, PyList_GET_SIZE((PyObject*)
%(x_name)
s
)) ;
{output_name} = (PyListObject*) PyList_GetSlice((PyObject*) {x_name}, 0, PyList_GET_SIZE((PyObject*) {x_name}
)) ;
"""
%
locals
(
)
"""
.
format
(
**
locals
()
)
else
:
else
:
init
=
f
"""
init
=
f
"""
{output_name} = {x_name};
{output_name} = {x_name};
...
@@ -350,15 +348,14 @@ class Insert(COp):
...
@@ -350,15 +348,14 @@ class Insert(COp):
return
(
return
(
init
init
+
"""
+
"""
if(
%(output_name)
s==NULL){
if({output_name}==NULL){{
%(fail)
s
{fail}
};
}};
if(PyList_Insert((PyObject*)
%(output_name)
s, *((npy_int64 *) PyArray_DATA(
%(index)
s)), (PyObject*)
%(toInsert)
s)==-1){
if(PyList_Insert((PyObject*) {output_name}, *((npy_int64 *) PyArray_DATA({index})), (PyObject*) {toInsert})==-1){{
%(fail)
s
{fail}
};
}};
Py_INCREF(
%(output_name)
s);
Py_INCREF({output_name});
"""
"""
.
format
(
**
locals
())
%
locals
()
)
)
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
...
@@ -470,8 +467,8 @@ class Reverse(COp):
...
@@ -470,8 +467,8 @@ class Reverse(COp):
fail
=
sub
[
"fail"
]
fail
=
sub
[
"fail"
]
if
not
self
.
inplace
:
if
not
self
.
inplace
:
init
=
"""
init
=
"""
%(output_name)
s = (PyListObject*) PyList_GetSlice((PyObject*)
%(x_name)
s, 0, PyList_GET_SIZE((PyObject*)
%(x_name)
s
)) ;
{output_name} = (PyListObject*) PyList_GetSlice((PyObject*) {x_name}, 0, PyList_GET_SIZE((PyObject*) {x_name}
)) ;
"""
%
locals
(
)
"""
.
format
(
**
locals
()
)
else
:
else
:
init
=
f
"""
init
=
f
"""
{output_name} = {x_name};
{output_name} = {x_name};
...
@@ -479,15 +476,14 @@ class Reverse(COp):
...
@@ -479,15 +476,14 @@ class Reverse(COp):
return
(
return
(
init
init
+
"""
+
"""
if(
%(output_name)
s==NULL){
if({output_name}==NULL){{
%(fail)
s
{fail}
};
}};
if(PyList_Reverse((PyObject*)
%(output_name)
s)==-1){
if(PyList_Reverse((PyObject*) {output_name})==-1){{
%(fail)
s
{fail}
};
}};
Py_INCREF(
%(output_name)
s);
Py_INCREF({output_name});
"""
"""
.
format
(
**
locals
())
%
locals
()
)
)
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
...
@@ -602,11 +598,11 @@ class Length(COp):
...
@@ -602,11 +598,11 @@ class Length(COp):
output_name
=
out
[
0
]
output_name
=
out
[
0
]
fail
=
sub
[
"fail"
]
fail
=
sub
[
"fail"
]
return
"""
return
"""
if(!
%(output_name)
s
)
if(!
{output_name}
)
%(output_name)
s
=(PyArrayObject*)PyArray_EMPTY(0, NULL, NPY_INT64, 0);
{output_name}
=(PyArrayObject*)PyArray_EMPTY(0, NULL, NPY_INT64, 0);
((npy_int64*)PyArray_DATA(
%(output_name)
s))[0]=PyList_Size((PyObject*)
%(x_name)
s
);
((npy_int64*)PyArray_DATA(
{output_name}))[0]=PyList_Size((PyObject*){x_name}
);
Py_INCREF(
%(output_name)
s
);
Py_INCREF(
{output_name}
);
"""
%
locals
(
)
"""
.
format
(
**
locals
()
)
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
return
(
1
,)
return
(
1
,)
...
...
pytensor/typed_list/type.py
浏览文件 @
06c5acdf
...
@@ -111,26 +111,25 @@ class TypedListType(CType):
...
@@ -111,26 +111,25 @@ class TypedListType(CType):
def
c_extract
(
self
,
name
,
sub
,
check_input
=
True
,
**
kwargs
):
def
c_extract
(
self
,
name
,
sub
,
check_input
=
True
,
**
kwargs
):
if
check_input
:
if
check_input
:
pre
=
"""
pre
=
"""
if (!PyList_Check(py_
%(name)
s))
{
if (!PyList_Check(py_
{name})) {
{
PyErr_SetString(PyExc_TypeError, "expected a list");
PyErr_SetString(PyExc_TypeError, "expected a list");
%(fail)
s
{fail}
}
"""
%
dict
(
name
=
name
,
fail
=
sub
[
"fail"
]
)
}
}"""
.
format
(
**
dict
(
name
=
name
,
fail
=
sub
[
"fail"
])
)
else
:
else
:
pre
=
""
pre
=
""
return
(
return
(
pre
pre
+
"""
+
"""
%(name)
s = (PyListObject*) (py_
%(name)
s);
{name} = (PyListObject*) (py_{name});
"""
"""
.
format
(
**
dict
(
name
=
name
,
fail
=
sub
[
"fail"
]))
%
dict
(
name
=
name
,
fail
=
sub
[
"fail"
])
)
)
def
c_sync
(
self
,
name
,
sub
):
def
c_sync
(
self
,
name
,
sub
):
return
"""
return
"""
Py_XDECREF(py_
%(name)
s
);
Py_XDECREF(py_
{name}
);
py_
%(name)
s = (PyObject*)(
%(name)
s
);
py_
{name} = (PyObject*)({name}
);
Py_INCREF(py_
%(name)
s
);
Py_INCREF(py_
{name}
);
"""
%
dict
(
name
=
name
)
"""
.
format
(
**
dict
(
name
=
name
)
)
def
c_cleanup
(
self
,
name
,
sub
):
def
c_cleanup
(
self
,
name
,
sub
):
return
""
return
""
...
...
tests/compile/test_debugmode.py
浏览文件 @
06c5acdf
差异被折叠。
点击展开。
tests/link/c/test_basic.py
浏览文件 @
06c5acdf
差异被折叠。
点击展开。
tests/link/c/test_op.py
浏览文件 @
06c5acdf
差异被折叠。
点击展开。
tests/link/c/test_params_type.py
浏览文件 @
06c5acdf
差异被折叠。
点击展开。
tests/link/c/test_type.py
浏览文件 @
06c5acdf
差异被折叠。
点击展开。
tests/tensor/conv/c_conv3d_corr3d_ref.py
浏览文件 @
06c5acdf
差异被折叠。
点击展开。
tests/tensor/conv/c_conv_corr_ref.py
浏览文件 @
06c5acdf
差异被折叠。
点击展开。
tests/tensor/rewriting/test_elemwise.py
浏览文件 @
06c5acdf
差异被折叠。
点击展开。
tests/tensor/utils.py
浏览文件 @
06c5acdf
差异被折叠。
点击展开。
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论