Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
06c5acdf
提交
06c5acdf
authored
1月 26, 2024
作者:
Virgile Andreani
提交者:
Ricardo Vieira
1月 27, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix UP031 with `ruff --unsafe-fixes`
上级
cada5ad2
全部展开
隐藏空白字符变更
内嵌
并排
正在显示
36 个修改的文件
包含
598 行增加
和
606 行删除
+598
-606
pyproject.toml
pyproject.toml
+1
-1
builders.py
pytensor/compile/builders.py
+1
-1
basic.py
pytensor/link/c/basic.py
+56
-62
cmodule.py
pytensor/link/c/cmodule.py
+5
-6
interface.py
pytensor/link/c/interface.py
+13
-11
op.py
pytensor/link/c/op.py
+16
-14
params_type.py
pytensor/link/c/params_type.py
+69
-70
type.py
pytensor/link/c/type.py
+53
-51
slinalg.py
pytensor/link/numba/dispatch/slinalg.py
+2
-2
printing.py
pytensor/printing.py
+1
-2
basic.py
pytensor/scalar/basic.py
+24
-20
math.py
pytensor/scalar/math.py
+10
-10
op.py
pytensor/scan/op.py
+2
-4
basic.py
pytensor/sparse/basic.py
+0
-0
rewriting.py
pytensor/sparse/rewriting.py
+0
-0
basic.py
pytensor/tensor/basic.py
+0
-0
blas.py
pytensor/tensor/blas.py
+58
-59
blas_c.py
pytensor/tensor/blas_c.py
+0
-0
blas_headers.py
pytensor/tensor/blas_headers.py
+23
-23
elemwise.py
pytensor/tensor/elemwise.py
+41
-41
elemwise_cgen.py
pytensor/tensor/elemwise_cgen.py
+23
-23
extra_ops.py
pytensor/tensor/extra_ops.py
+63
-63
math.py
pytensor/tensor/math.py
+5
-5
subtensor.py
pytensor/tensor/subtensor.py
+0
-0
type.py
pytensor/tensor/type.py
+69
-70
basic.py
pytensor/typed_list/basic.py
+54
-58
type.py
pytensor/typed_list/type.py
+9
-10
test_debugmode.py
tests/compile/test_debugmode.py
+0
-0
test_basic.py
tests/link/c/test_basic.py
+0
-0
test_op.py
tests/link/c/test_op.py
+0
-0
test_params_type.py
tests/link/c/test_params_type.py
+0
-0
test_type.py
tests/link/c/test_type.py
+0
-0
c_conv3d_corr3d_ref.py
tests/tensor/conv/c_conv3d_corr3d_ref.py
+0
-0
c_conv_corr_ref.py
tests/tensor/conv/c_conv_corr_ref.py
+0
-0
test_elemwise.py
tests/tensor/rewriting/test_elemwise.py
+0
-0
utils.py
tests/tensor/utils.py
+0
-0
没有找到文件。
pyproject.toml
浏览文件 @
06c5acdf
...
...
@@ -130,7 +130,7 @@ disable = ["C0330", "C0326"]
[tool.ruff]
select
=
[
"C"
,
"E"
,
"F"
,
"I"
,
"UP"
,
"W"
]
ignore
=
[
"C408"
,
"C901"
,
"E501"
,
"E741"
,
"UP031"
]
ignore
=
[
"C408"
,
"C901"
,
"E501"
,
"E741"
]
exclude
=
[
"doc/"
,
"pytensor/_version.py"
,
"bin/pytensor_cache.py"
]
...
...
pytensor/compile/builders.py
浏览文件 @
06c5acdf
...
...
@@ -465,7 +465,7 @@ class OpFromGraph(Op, HasInnerGraph):
def
__str__
(
self
):
name
=
self
.
__class__
.
__name__
if
self
.
name
is
None
else
self
.
name
is_inline
=
self
.
is_inline
return
"
%(name)
s{inline=
%(is_inline)
s}"
%
locals
(
)
return
"
{name}{{inline={is_inline}}}"
.
format
(
**
locals
()
)
@config.change_flags
(
compute_test_value
=
"off"
)
def
_recompute_lop_op
(
self
):
...
...
pytensor/link/c/basic.py
浏览文件 @
06c5acdf
...
...
@@ -250,33 +250,30 @@ def struct_gen(args, struct_builders, blocks, sub):
# that holds the type, the value and the traceback. After storing
# the error, we return the failure code so we know which code
# block failed.
do_return
=
(
"""
if (
%(failure_var)
s) {
do_return
=
"""
if ({failure_var}) {{
// When there is a failure, this code puts the exception
// in __ERROR.
PyObject* err_type = NULL;
PyObject* err_msg = NULL;
PyObject* err_traceback = NULL;
PyErr_Fetch(&err_type, &err_msg, &err_traceback);
if (!err_type) {
err_type = Py_None;Py_INCREF(Py_None);
}
if (!err_msg) {
err_msg = Py_None; Py_INCREF(Py_None);
}
if (!err_traceback) {
err_traceback = Py_None; Py_INCREF(Py_None);
}
if (!err_type) {
{err_type = Py_None;Py_INCREF(Py_None);}
}
if (!err_msg) {
{err_msg = Py_None; Py_INCREF(Py_None);}
}
if (!err_traceback) {
{err_traceback = Py_None; Py_INCREF(Py_None);}
}
PyObject* old_err_type = PyList_GET_ITEM(__ERROR, 0);
PyObject* old_err_msg = PyList_GET_ITEM(__ERROR, 1);
PyObject* old_err_traceback = PyList_GET_ITEM(__ERROR, 2);
PyList_SET_ITEM(__ERROR, 0, err_type);
PyList_SET_ITEM(__ERROR, 1, err_msg);
PyList_SET_ITEM(__ERROR, 2, err_traceback);
{
Py_XDECREF(old_err_type);
}
{
Py_XDECREF(old_err_msg);
}
{
Py_XDECREF(old_err_traceback);
}
}
{
{Py_XDECREF(old_err_type);}
}
{
{Py_XDECREF(old_err_msg);}
}
{
{Py_XDECREF(old_err_traceback);}
}
}
}
// The failure code is returned to index what code block failed.
return
%(failure_var)
s;
"""
%
sub
)
return {failure_var};
"""
.
format
(
**
sub
)
sub
=
dict
(
sub
)
sub
.
update
(
locals
())
...
...
@@ -284,16 +281,15 @@ def struct_gen(args, struct_builders, blocks, sub):
# TODO: add some error checking to make sure storage_<x> are
# 1-element lists and __ERROR is a 3-elements list.
struct_code
=
(
"""
namespace {
struct
%(name)
s {
struct_code
=
"""
namespace {{
struct {name} {{
PyObject* __ERROR;
%(storage_decl)
s
%(struct_decl)
s
{storage_decl}
{struct_decl}
%(name)
s()
{
{name}() {
{
// This is only somewhat safe because we:
// 1) Are not a virtual class
// 2) Do not use any virtual classes in the members
...
...
@@ -306,32 +302,30 @@ def struct_gen(args, struct_builders, blocks, sub):
#ifndef PYTENSOR_DONT_MEMSET_STRUCT
memset(this, 0, sizeof(*this));
#endif
}
~
%(name)
s(void)
{
}
}
~
{name}(void) {
{
cleanup();
}
}
}
int init(PyObject* __ERROR,
%(args_decl)
s)
{
%(storage_incref)
s
%(storage_set)
s
%(struct_init_head)
s
int init(PyObject* __ERROR,
{args_decl}) {
{
{storage_incref}
{storage_set}
{struct_init_head}
this->__ERROR = __ERROR;
return 0;
}
void cleanup(void) {
%(struct_cleanup)
s
%(storage_decref)
s
}
int run(void) {
int
%(failure_var)
s = 0;
%(behavior)
s
%(do_return)
s
}
};
}
"""
%
sub
)
}}
void cleanup(void) {{
{struct_cleanup}
{storage_decref}
}}
int run(void) {{
int {failure_var} = 0;
{behavior}
{do_return}
}}
}};
}}
"""
.
format
(
**
sub
)
return
struct_code
...
...
@@ -380,9 +374,9 @@ def get_c_init(fgraph, r, name, sub):
pre
=
(
""
"""
py_
%(name)
s
= Py_None;
{
Py_XINCREF(py_
%(name)
s);
}
"""
%
locals
(
)
py_
{name}
= Py_None;
{
{Py_XINCREF(py_{name});}
}
"""
.
format
(
**
locals
()
)
)
return
pre
+
r
.
type
.
c_init
(
name
,
sub
)
...
...
@@ -418,9 +412,9 @@ def get_c_extract(fgraph, r, name, sub):
c_extract
=
r
.
type
.
c_extract
(
name
,
sub
,
False
)
pre
=
"""
py_
%(name)
s = PyList_GET_ITEM(storage_
%(name)
s
, 0);
{
Py_XINCREF(py_
%(name)
s);
}
"""
%
locals
(
)
py_
{name} = PyList_GET_ITEM(storage_{name}
, 0);
{
{Py_XINCREF(py_{name});}
}
"""
.
format
(
**
locals
()
)
return
pre
+
c_extract
...
...
@@ -447,9 +441,9 @@ def get_c_extract_out(fgraph, r, name, sub):
c_extract
=
r
.
type
.
c_extract_out
(
name
,
sub
,
check_input
,
check_broadcast
=
False
)
pre
=
"""
py_
%(name)
s = PyList_GET_ITEM(storage_
%(name)
s
, 0);
{
Py_XINCREF(py_
%(name)
s);
}
"""
%
locals
(
)
py_
{name} = PyList_GET_ITEM(storage_{name}
, 0);
{
{Py_XINCREF(py_{name});}
}
"""
.
format
(
**
locals
()
)
return
pre
+
c_extract
...
...
@@ -459,8 +453,8 @@ def get_c_cleanup(fgraph, r, name, sub):
"""
post
=
"""
{
Py_XDECREF(py_
%(name)
s);
}
"""
%
locals
(
)
{
{Py_XDECREF(py_{name});}
}
"""
.
format
(
**
locals
()
)
return
r
.
type
.
c_cleanup
(
name
,
sub
)
+
post
...
...
@@ -470,14 +464,14 @@ def get_c_sync(fgraph, r, name, sub):
"""
return
"""
if (!
%(failure_var)
s)
{
%(sync)
s
PyObject* old = PyList_GET_ITEM(storage_
%(name)
s
, 0);
{
Py_XINCREF(py_
%(name)
s);
}
PyList_SET_ITEM(storage_
%(name)
s, 0, py_
%(name)
s
);
{
Py_XDECREF(old);
}
}
"""
%
dict
(
sync
=
r
.
type
.
c_sync
(
name
,
sub
),
name
=
name
,
**
sub
)
if (!
{failure_var}) {
{
{sync}
PyObject* old = PyList_GET_ITEM(storage_
{name}
, 0);
{
{Py_XINCREF(py_{name});}
}
PyList_SET_ITEM(storage_
{name}, 0, py_{name}
);
{
{Py_XDECREF(old);}
}
}
}
"""
.
format
(
**
dict
(
sync
=
r
.
type
.
c_sync
(
name
,
sub
),
name
=
name
,
**
sub
)
)
def
apply_policy
(
fgraph
,
policy
,
r
,
name
,
sub
):
...
...
pytensor/link/c/cmodule.py
浏览文件 @
06c5acdf
...
...
@@ -1950,14 +1950,13 @@ class Compiler:
code
=
(
"""
%(preamble)
s
{preamble}
int main(int argc, char** argv)
{
%(body)
s
{
{
{body}
return 0;
}
"""
%
locals
()
}}
"""
.
format
(
**
locals
())
)
.
encode
()
return
cls
.
_try_compile_tmp
(
code
,
...
...
pytensor/link/c/interface.py
浏览文件 @
06c5acdf
...
...
@@ -558,18 +558,20 @@ class CLinkerType(CLinkerObject):
"""
return
"""
if (py_
%(name)
s
== Py_None)
{
%(c_init_code)
s
}
if (py_
{name}
== Py_None)
{
{
{c_init_code}
}
}
else
{
%(c_extract_code)
s
}
"""
%
dict
(
name
=
name
,
c_init_code
=
self
.
c_init
(
name
,
sub
),
c_extract_code
=
self
.
c_extract
(
name
,
sub
,
check_input
),
{{
{c_extract_code}
}}
"""
.
format
(
**
dict
(
name
=
name
,
c_init_code
=
self
.
c_init
(
name
,
sub
),
c_extract_code
=
self
.
c_extract
(
name
,
sub
,
check_input
),
)
)
def
c_cleanup
(
self
,
name
:
str
,
sub
:
dict
[
str
,
str
])
->
str
:
...
...
pytensor/link/c/op.py
浏览文件 @
06c5acdf
...
...
@@ -596,20 +596,22 @@ class ExternalCOp(COp):
# Generate the C code
return
"""
%(define_macros)
s
{
if (
%(func_name)
s(
%(func_args)
s
%(params)
s) != 0) {
%(fail)
s
}
}
%(undef_macros)
s
"""
%
dict
(
func_name
=
self
.
func_name
,
fail
=
sub
[
"fail"
],
params
=
params
,
func_args
=
self
.
format_c_function_args
(
inp
,
out
),
define_macros
=
define_macros
,
undef_macros
=
undef_macros
,
{define_macros}
{{
if ({func_name}({func_args}{params}) != 0) {{
{fail}
}}
}}
{undef_macros}
"""
.
format
(
**
dict
(
func_name
=
self
.
func_name
,
fail
=
sub
[
"fail"
],
params
=
params
,
func_args
=
self
.
format_c_function_args
(
inp
,
out
),
define_macros
=
define_macros
,
undef_macros
=
undef_macros
,
)
)
else
:
if
"code"
in
self
.
code_sections
:
...
...
pytensor/link/c/params_type.py
浏览文件 @
06c5acdf
...
...
@@ -359,8 +359,7 @@ class ParamsType(CType):
type_name
=
type_instance
.
__class__
.
__name__
if
not
isinstance
(
type_instance
,
CType
):
raise
TypeError
(
'ParamsType: attribute "
%
s" should inherit from PyTensor CType, got "
%
s".'
%
(
attribute_name
,
type_name
)
f
'ParamsType: attribute "{attribute_name}" should inherit from PyTensor CType, got "{type_name}".'
)
self
.
length
=
len
(
kwargs
)
...
...
@@ -723,15 +722,11 @@ class ParamsType(CType):
c_cleanup_list
.
append
(
type_instance
.
c_cleanup
(
attribute_name
,
sub
))
c_extract_list
.
append
(
"""
void extract_
%(attribute_name)
s(PyObject* py_
%(attribute_name)
s)
{
%(extract_code)
s
}
f
"""
void extract_
{attribute_name}(PyObject* py_{attribute_name}) {
{
{type_instance.c_extract(attribute_name, sub)}
}
}
"""
%
{
"attribute_name"
:
attribute_name
,
"extract_code"
:
type_instance
.
c_extract
(
attribute_name
,
sub
),
}
)
struct_declare
=
"
\n
"
.
join
(
c_declare_list
)
...
...
@@ -759,55 +754,57 @@ class ParamsType(CType):
)
)
final_struct_code
=
"""
/** ParamsType
%(struct_name)
s
**/
#ifndef
%(struct_name_defined)
s
#define
%(struct_name_defined)
s
struct
%(struct_name)
s
{
/** ParamsType
{struct_name}
**/
#ifndef
{struct_name_defined}
#define
{struct_name_defined}
struct
{struct_name} {
{
/* Attributes, */
int
%(struct_name)
s
_error;
%(struct_declare)
s
int
{struct_name}
_error;
{struct_declare}
/* Constructor. */
%(struct_name)
s()
{
%(struct_name)
s
_error = 0;
%(struct_init)
s
}
{struct_name}() {
{
{struct_name}
_error = 0;
{struct_init}
}
}
/* Destructor. */
~
%(struct_name)
s()
{
~
{struct_name}() {
{
// cleanup() is defined below.
cleanup();
}
}
}
/* Cleanup method. */
void cleanup() {
%(struct_cleanup)
s
}
void cleanup() {
{
{struct_cleanup}
}
}
/* Extraction methods. */
%(struct_extract)
s
{struct_extract}
/* Extract method. */
%(struct_extract_method)
s
{struct_extract_method}
/* Other methods. */
void setErrorOccurred() {
++
%(struct_name)
s
_error;
}
int errorOccurred() {
return
%(struct_name)
s
_error;
}
};
void setErrorOccurred() {
{
++
{struct_name}
_error;
}
}
int errorOccurred() {
{
return
{struct_name}
_error;
}
}
}
}
;
#endif
/** End ParamsType
%(struct_name)
s **/
"""
%
dict
(
struct_name_defined
=
struct_name_defined
,
struct_name
=
struct_name
,
struct_declare
=
struct_declare
,
struct_init
=
struct_init
,
struct_cleanup
=
struct_cleanup
,
struct_extract
=
struct_extract
,
struct_extract_method
=
struct_extract_method
,
/** End ParamsType {struct_name} **/
"""
.
format
(
**
dict
(
struct_name_defined
=
struct_name_defined
,
struct_name
=
struct_name
,
struct_declare
=
struct_declare
,
struct_init
=
struct_init
,
struct_cleanup
=
struct_cleanup
,
struct_extract
=
struct_extract
,
struct_extract_method
=
struct_extract_method
,
)
)
return
sorted
(
c_support_code_set
)
+
[
final_struct_code
]
...
...
@@ -822,8 +819,8 @@ class ParamsType(CType):
def
c_declare
(
self
,
name
,
sub
,
check_input
=
True
):
return
"""
%(struct_name)
s*
%(name)
s
;
"""
%
dict
(
struct_name
=
self
.
name
,
name
=
name
)
{struct_name}* {name}
;
"""
.
format
(
**
dict
(
struct_name
=
self
.
name
,
name
=
name
)
)
def
c_init
(
self
,
name
,
sub
):
# NB: It seems c_init() is not called for an op param.
...
...
@@ -841,35 +838,37 @@ class ParamsType(CType):
def
c_extract
(
self
,
name
,
sub
,
check_input
=
True
,
**
kwargs
):
return
"""
/* Seems c_init() is not called for a op param. So I call `new` here. */
%(name)
s = new
%(struct_name)
s
;
{name} = new {struct_name}
;
{ // This need a separate namespace for Clinker
const char* fields[] = {
%(fields_list)
s
};
if (py_
%(name)
s == Py_None)
{
{
{
// This need a separate namespace for Clinker
const char* fields[] = {
{{fields_list}}
};
if (py_
{name} == Py_None) {
{
PyErr_SetString(PyExc_ValueError, "ParamsType: expected an object, not None.");
%(fail)
s
}
for (int i = 0; i <
%(length)
s; ++i)
{
PyObject* o = PyDict_GetItemString(py_
%(name)
s
, fields[i]);
if (o == NULL) {
PyErr_Format(PyExc_TypeError, "ParamsType: missing expected attribute
\\
"
%
%
s
\\
" in object.", fields[i]);
%(fail)
s
}
%(name)
s
->extract(o, i);
if (
%(name)
s->errorOccurred())
{
{fail}
}
}
for (int i = 0; i <
{length}; ++i) {
{
PyObject* o = PyDict_GetItemString(py_
{name}
, fields[i]);
if (o == NULL) {
{
PyErr_Format(PyExc_TypeError, "ParamsType: missing expected attribute
\\
"
%
s
\\
" in object.", fields[i]);
{fail}
}
}
{name}
->extract(o, i);
if (
{name}->errorOccurred()) {
{
/* The extract code from attribute type should have already raised a Python exception,
* so we just print the attribute name in stderr. */
fprintf(stderr, "
\\
nParamsType: error when extracting value for attribute
\\
"
%%
s
\\
".
\\
n", fields[i]);
%(fail)
s
}
}
}
"""
%
dict
(
name
=
name
,
struct_name
=
self
.
name
,
length
=
self
.
length
,
fail
=
sub
[
"fail"
],
fields_list
=
'"
%
s"'
%
'", "'
.
join
(
self
.
fields
),
fprintf(stderr, "
\\
nParamsType: error when extracting value for attribute
\\
"
%
s
\\
".
\\
n", fields[i]);
{fail}
}}
}}
}}
"""
.
format
(
**
dict
(
name
=
name
,
struct_name
=
self
.
name
,
length
=
self
.
length
,
fail
=
sub
[
"fail"
],
fields_list
=
'"
%
s"'
%
'", "'
.
join
(
self
.
fields
),
)
)
def
c_sync
(
self
,
name
,
sub
):
...
...
pytensor/link/c/type.py
浏览文件 @
06c5acdf
...
...
@@ -99,11 +99,11 @@ class Generic(CType, Singleton):
def
c_sync
(
self
,
name
,
sub
):
return
"""
assert(py_
%(name)
s
->ob_refcnt > 1);
Py_DECREF(py_
%(name)
s
);
py_
%(name)
s =
%(name)
s ?
%(name)
s
: Py_None;
Py_INCREF(py_
%(name)
s
);
"""
%
locals
(
)
assert(py_
{name}
->ob_refcnt > 1);
Py_DECREF(py_
{name}
);
py_
{name} = {name} ? {name}
: Py_None;
Py_INCREF(py_
{name}
);
"""
.
format
(
**
locals
()
)
def
c_code_cache_version
(
self
):
return
(
1
,)
...
...
@@ -191,17 +191,17 @@ class CDataType(CType[D]):
def
c_declare
(
self
,
name
,
sub
,
check_input
=
True
):
return
"""
%(ctype)
s
%(name)
s
;
"""
%
dict
(
ctype
=
self
.
ctype
,
name
=
name
)
{ctype} {name}
;
"""
.
format
(
**
dict
(
ctype
=
self
.
ctype
,
name
=
name
)
)
def
c_init
(
self
,
name
,
sub
):
return
f
"{name} = NULL;"
def
c_extract
(
self
,
name
,
sub
,
check_input
=
True
,
**
kwargs
):
return
"""
%(name)
s = (
%(ctype)
s)PyCapsule_GetPointer(py_
%(name)
s
, NULL);
if (
%(name)
s == NULL)
%(fail)
s
"""
%
dict
(
name
=
name
,
ctype
=
self
.
ctype
,
fail
=
sub
[
"fail"
]
)
{name} = ({ctype})PyCapsule_GetPointer(py_{name}
, NULL);
if (
{name} == NULL) {fail}
"""
.
format
(
**
dict
(
name
=
name
,
ctype
=
self
.
ctype
,
fail
=
sub
[
"fail"
])
)
def
c_sync
(
self
,
name
,
sub
):
freefunc
=
self
.
freefunc
...
...
@@ -576,39 +576,39 @@ class EnumType(CType, dict):
"""
return
"""
#ifdef DEBUG
int pytensor_enum_to_string_
%(cname)
s(
%(ctype)
s in, char* out)
{
int pytensor_enum_to_string_
{cname}({ctype} in, char* out) {
{
int ret = 0;
switch(in) {
%(cases)
s
switch(in) {
{
{cases}
default:
PyErr_SetString(PyExc_ValueError, "
%(classname)
s
: unknown enum value.");
PyErr_SetString(PyExc_ValueError, "
{classname}
: unknown enum value.");
ret = -1;
break;
}
}
}
return ret;
}
}
}
#endif
"""
%
dict
(
cname
=
self
.
cname
,
ctype
=
self
.
ctype
,
classname
=
type
(
self
)
.
__name__
,
cases
=
""
.
join
(
"""
case
%(name)
s: sprintf(out, "
%(name)
s"); break;
"""
%
dict
(
name
=
name
)
for
name
in
self
),
"""
.
format
(
**
dict
(
cname
=
self
.
cname
,
ctype
=
self
.
ctype
,
classname
=
type
(
self
)
.
__name__
,
cases
=
""
.
join
(
"""
case {name}: sprintf(out, "{name}"); break;
"""
.
format
(
**
dict
(
name
=
name
))
for
name
in
self
),
)
)
def
c_support_code
(
self
,
**
kwargs
):
return
(
self
.
pyint_compat_code
+
""
.
join
(
"""
#define
%
s
%
s
f
"""
#define
{k} {str(self[k])}
"""
%
(
k
,
str
(
self
[
k
]))
for
k
in
sorted
(
self
.
keys
())
)
+
self
.
c_to_string
()
...
...
@@ -625,15 +625,15 @@ class EnumType(CType, dict):
def
c_extract
(
self
,
name
,
sub
,
check_input
=
True
,
**
kwargs
):
return
"""
if (PyInt_Check(py_
%(name)
s))
{
%(name)
s = (
%(ctype)
s)PyInt_AsLong(py_
%(name)
s
);
}
else
{
%(name)
s = (
%(ctype)
s)PyFloat_AsDouble(py_
%(name)
s
);
}
if (PyErr_Occurred()) {
%(fail)
s
}
"""
%
dict
(
ctype
=
self
.
ctype
,
name
=
name
,
fail
=
sub
[
"fail"
]
)
if (PyInt_Check(py_
{name})) {
{
{name} = ({ctype})PyInt_AsLong(py_{name}
);
}
} else {
{
{name} = ({ctype})PyFloat_AsDouble(py_{name}
);
}
}
if (PyErr_Occurred()) {
{
{fail}
}
}
"""
.
format
(
**
dict
(
ctype
=
self
.
ctype
,
name
=
name
,
fail
=
sub
[
"fail"
])
)
def
c_code_cache_version
(
self
):
return
(
2
,
self
.
ctype
,
self
.
cname
,
tuple
(
self
.
items
()))
...
...
@@ -754,23 +754,25 @@ class CEnumType(EnumList):
# swapped_dict's keys are integers.
return
"""
switch(PyInt_AsLong(py_
%(name)
s))
{
%(cases)
s
switch(PyInt_AsLong(py_
{name})) {
{
{cases}
default:
PyErr_SetString(PyExc_ValueError, "CEnumType: invalid value to map to C constants.");
{
%(fail)
s
}
{
{{fail}}
}
break;
}
"""
%
dict
(
name
=
name
,
cases
=
""
.
join
(
"""
}}
"""
.
format
(
**
dict
(
name
=
name
,
cases
=
""
.
join
(
"""
case
%(i)
d:
%(name)
s =
%(constant_cname)
s; break;
"""
%
dict
(
i
=
i
,
name
=
name
,
constant_cname
=
swapped_dict
[
i
])
for
i
in
sorted
(
swapped_dict
.
keys
())
),
fail
=
sub
[
"fail"
],
%
dict
(
i
=
i
,
name
=
name
,
constant_cname
=
swapped_dict
[
i
])
for
i
in
sorted
(
swapped_dict
.
keys
())
),
fail
=
sub
[
"fail"
],
)
)
def
c_code_cache_version
(
self
):
...
...
pytensor/link/numba/dispatch/slinalg.py
浏览文件 @
06c5acdf
...
...
@@ -141,7 +141,7 @@ def _check_scipy_linalg_matrix(a, func_name):
if
isinstance
(
a
,
types
.
Optional
):
a
=
a
.
type
if
not
isinstance
(
a
,
types
.
Array
):
msg
=
"
%
s.
%
s() only supported for array types"
%
interp
msg
=
"
{}.{}() only supported for array types"
.
format
(
*
interp
)
raise
numba
.
TypingError
(
msg
,
highlighting
=
False
)
if
a
.
ndim
not
in
[
1
,
2
]:
msg
=
"
%
s.
%
s() only supported on 1d or 2d arrays, found
%
s."
%
(
...
...
@@ -149,7 +149,7 @@ def _check_scipy_linalg_matrix(a, func_name):
)
raise
numba
.
TypingError
(
msg
,
highlighting
=
False
)
if
not
isinstance
(
a
.
dtype
,
(
types
.
Float
,
types
.
Complex
)):
msg
=
"
%
s.
%
s() only supported on "
"float and complex arrays."
%
interp
msg
=
"
{}.{}() only supported on "
"float and complex arrays."
.
format
(
*
interp
)
raise
numba
.
TypingError
(
msg
,
highlighting
=
False
)
...
...
pytensor/printing.py
浏览文件 @
06c5acdf
...
...
@@ -646,8 +646,7 @@ def _debugprint(
tot_time_percent
=
(
tot_time_dict
[
node
]
/
profile
.
fct_call_time
)
*
100
print
(
"
%
s -->
%8.2
es
%4.1
f
%% %8.2
es
%4.1
f
%%
"
%
(
"{} --> {:8.2e}s {:4.1f}
%
{:8.2e}s {:4.1f}
%
"
.
format
(
var_output
,
op_time
,
op_time_percent
,
...
...
pytensor/scalar/basic.py
浏览文件 @
06c5acdf
...
...
@@ -466,40 +466,44 @@ class ScalarType(CType, HasDataType, HasShape):
specs
=
self
.
dtype_specs
()
if
check_input
:
pre
=
"""
if (!PyObject_TypeCheck(py_
%(name)
s, &
%(pyarr_type)
s
))
{
if (!PyObject_TypeCheck(py_
{name}, &{pyarr_type}
))
{
{
PyErr_Format(PyExc_ValueError,
"Scalar check failed (
%(dtype)
s)");
%(fail)
s
}
"""
%
dict
(
sub
,
name
=
name
,
dtype
=
specs
[
1
],
pyarr_type
=
"Py
%
sArrType_Type"
%
specs
[
2
]
"Scalar check failed ({dtype})");
{fail}
}}
"""
.
format
(
**
dict
(
sub
,
name
=
name
,
dtype
=
specs
[
1
],
pyarr_type
=
"Py
%
sArrType_Type"
%
specs
[
2
],
)
)
else
:
pre
=
""
return
(
pre
+
"""
PyArray_ScalarAsCtype(py_
%(name)
s, &
%(name)
s);
"""
%
dict
(
sub
,
name
=
name
)
PyArray_ScalarAsCtype(py_{name}, &{name});
"""
.
format
(
**
dict
(
sub
,
name
=
name
))
)
def
c_sync
(
self
,
name
,
sub
):
specs
=
self
.
dtype_specs
()
return
"""
Py_XDECREF(py_
%(name)
s
);
py_
%(name)
s = PyArrayScalar_New(
%(cls)
s
);
if (!py_
%(name)
s
)
{
Py_XDECREF(py_
{name}
);
py_
{name} = PyArrayScalar_New({cls}
);
if (!py_
{name}
)
{
{
Py_XINCREF(Py_None);
py_
%(name)
s
= Py_None;
py_
{name}
= Py_None;
PyErr_Format(PyExc_MemoryError,
"Instantiation of new Python scalar failed (
%(dtype)
s
)");
%(fail)
s
}
PyArrayScalar_ASSIGN(py_
%(name)
s,
%(cls)
s,
%(name)
s
);
"""
%
dict
(
sub
,
name
=
name
,
dtype
=
specs
[
1
],
cls
=
specs
[
2
]
)
"Instantiation of new Python scalar failed (
{dtype}
)");
{fail}
}
}
PyArrayScalar_ASSIGN(py_
{name}, {cls}, {name}
);
"""
.
format
(
**
dict
(
sub
,
name
=
name
,
dtype
=
specs
[
1
],
cls
=
specs
[
2
])
)
def
c_cleanup
(
self
,
name
,
sub
):
return
""
...
...
pytensor/scalar/math.py
浏览文件 @
06c5acdf
...
...
@@ -620,8 +620,8 @@ class Chi2SF(BinaryScalarOp):
(
z
,)
=
out
if
node
.
inputs
[
0
]
.
type
in
float_types
:
dtype
=
"npy_"
+
node
.
outputs
[
0
]
.
dtype
return
"""
%(z)
s
=
(
%(dtype)
s) 1 - GammaP(
%(k)
s/2.,
%(x)
s/2.);"""
%
locals
(
)
return
"""
{z}
=
(
{dtype}) 1 - GammaP({k}/2., {x}/2.);"""
.
format
(
**
locals
()
)
raise
NotImplementedError
(
"only floatingpoint is implemented"
)
def
__eq__
(
self
,
other
):
...
...
@@ -666,8 +666,8 @@ class GammaInc(BinaryScalarOp):
(
z
,)
=
out
if
node
.
inputs
[
0
]
.
type
in
float_types
:
dtype
=
"npy_"
+
node
.
outputs
[
0
]
.
dtype
return
"""
%(z)
s
=
(
%(dtype)
s) GammaP(
%(k)
s,
%(x)
s);"""
%
locals
(
)
return
"""
{z}
=
(
{dtype}) GammaP({k}, {x});"""
.
format
(
**
locals
()
)
raise
NotImplementedError
(
"only floatingpoint is implemented"
)
def
__eq__
(
self
,
other
):
...
...
@@ -712,8 +712,8 @@ class GammaIncC(BinaryScalarOp):
(
z
,)
=
out
if
node
.
inputs
[
0
]
.
type
in
float_types
:
dtype
=
"npy_"
+
node
.
outputs
[
0
]
.
dtype
return
"""
%(z)
s
=
(
%(dtype)
s) GammaQ(
%(k)
s,
%(x)
s);"""
%
locals
(
)
return
"""
{z}
=
(
{dtype}) GammaQ({k}, {x});"""
.
format
(
**
locals
()
)
raise
NotImplementedError
(
"only floatingpoint is implemented"
)
def
__eq__
(
self
,
other
):
...
...
@@ -1018,8 +1018,8 @@ class GammaU(BinaryScalarOp):
(
z
,)
=
out
if
node
.
inputs
[
0
]
.
type
in
float_types
:
dtype
=
"npy_"
+
node
.
outputs
[
0
]
.
dtype
return
"""
%(z)
s
=
(
%(dtype)
s) upperGamma(
%(k)
s,
%(x)
s);"""
%
locals
(
)
return
"""
{z}
=
(
{dtype}) upperGamma({k}, {x});"""
.
format
(
**
locals
()
)
raise
NotImplementedError
(
"only floatingpoint is implemented"
)
def
__eq__
(
self
,
other
):
...
...
@@ -1056,8 +1056,8 @@ class GammaL(BinaryScalarOp):
(
z
,)
=
out
if
node
.
inputs
[
0
]
.
type
in
float_types
:
dtype
=
"npy_"
+
node
.
outputs
[
0
]
.
dtype
return
"""
%(z)
s
=
(
%(dtype)
s) lowerGamma(
%(k)
s,
%(x)
s);"""
%
locals
(
)
return
"""
{z}
=
(
{dtype}) lowerGamma({k}, {x});"""
.
format
(
**
locals
()
)
raise
NotImplementedError
(
"only floatingpoint is implemented"
)
def
__eq__
(
self
,
other
):
...
...
pytensor/scan/op.py
浏览文件 @
06c5acdf
...
...
@@ -3364,8 +3364,7 @@ def profile_printer(
total_scan_fct_time
+=
scan_fct_time
total_scan_op_time
+=
scan_op_time
print
(
"
%5.1
fs
%5.1
fs
%5.1
fs
%5.1
f
%% %5.1
f
%%
"
%
(
" {:5.1f}s {:5.1f}s {:5.1f}s {:5.1f}
%
{:5.1f}
%
"
.
format
(
v
,
scan_fct_time
,
scan_op_time
,
...
...
@@ -3385,8 +3384,7 @@ def profile_printer(
print
(
" No scan have its inner profile enabled."
,
file
=
file
)
else
:
print
(
"total
%5.1
fs
%5.1
fs
%5.1
fs
%5.1
f
%% %5.1
f
%%
"
%
(
"total {:5.1f}s {:5.1f}s {:5.1f}s {:5.1f}
%
{:5.1f}
%
"
.
format
(
total_super_scan_time
,
total_scan_fct_time
,
total_scan_op_time
,
...
...
pytensor/sparse/basic.py
浏览文件 @
06c5acdf
差异被折叠。
点击展开。
pytensor/sparse/rewriting.py
浏览文件 @
06c5acdf
差异被折叠。
点击展开。
pytensor/tensor/basic.py
浏览文件 @
06c5acdf
差异被折叠。
点击展开。
pytensor/tensor/blas.py
浏览文件 @
06c5acdf
...
...
@@ -1840,19 +1840,19 @@ class BatchedDot(COp):
z_shape
=
", "
.
join
(
z_dims
)
z_contiguous
=
contiguous
(
_z
,
z_ndim
)
allocate
=
"""
if (NULL ==
%(_z)
s || !(
%(z_shape_correct)
s) || !(
%(z_contiguous)
s
))
{
npy_intp dims[
%(z_ndim)
s] = {
%(z_shape)
s
};
Py_XDECREF(
%(_z)
s
);
%(_z)
s
= (PyArrayObject*)PyArray_SimpleNew(
%(z_ndim)
s, dims, PyArray_TYPE(
%(_x)
s
));
if(!
%(_z)
s)
{
if (NULL ==
{_z} || !({z_shape_correct}) || !({z_contiguous}
))
{
{
npy_intp dims[
{z_ndim}] = {{{z_shape}}
};
Py_XDECREF(
{_z}
);
{_z}
= (PyArrayObject*)PyArray_SimpleNew(
{z_ndim}, dims, PyArray_TYPE({_x}
));
if(!
{_z}) {
{
PyErr_SetString(PyExc_MemoryError,
"failed to alloc BatchedDot output");
%(fail)
s
}
}
"""
%
locals
(
)
{fail}
}
}
}
}
"""
.
format
(
**
locals
()
)
# code to reallocate inputs contiguously if necessary
contiguate
=
[]
...
...
@@ -1860,76 +1860,75 @@ class BatchedDot(COp):
_contiguous
=
contiguous
(
var
,
ndim
)
contiguate
.
append
(
"""
if (!(
%(_contiguous)
s))
{
PyArrayObject * _copy = (PyArrayObject *) PyArray_Copy(
%(var)
s
);
if (!(
{_contiguous})) {
{
PyArrayObject * _copy = (PyArrayObject *) PyArray_Copy(
{var}
);
if (!_copy)
%(fail)
s
Py_XDECREF(
%(var)
s);
%(var)
s = _copy;
}
"""
%
locals
()
{fail}
Py_XDECREF({var});
{var} = _copy;
}}
"""
.
format
(
**
locals
())
)
contiguate
=
"
\n
"
.
join
(
contiguate
)
return
"""
int type_num = PyArray_DESCR(
%(_x)
s
)->type_num;
int type_size = PyArray_DESCR(
%(_x)
s
)->elsize; // in bytes
int type_num = PyArray_DESCR(
{_x}
)->type_num;
int type_size = PyArray_DESCR(
{_x}
)->elsize; // in bytes
if (PyArray_NDIM(
%(_x)
s) != 3)
{
if (PyArray_NDIM(
{_x}) != 3) {
{
PyErr_Format(PyExc_NotImplementedError,
"rank(x) != 3. rank(x) is
%
%
d.",
PyArray_NDIM(
%(_x)
s
));
%(fail)
s
;
}
if (PyArray_NDIM(
%(_y)
s) != 3)
{
"rank(x) != 3. rank(x) is
%
d.",
PyArray_NDIM(
{_x}
));
{fail}
;
}
}
if (PyArray_NDIM(
{_y}) != 3) {
{
PyErr_Format(PyExc_NotImplementedError,
"rank(y) != 3. rank(y) is
%
%
d.",
PyArray_NDIM(
%(_y)
s
));
%(fail)
s
;
}
if (
%(_z)
s && PyArray_NDIM(
%(_z)
s) != 3)
{
"rank(y) != 3. rank(y) is
%
d.",
PyArray_NDIM(
{_y}
));
{fail}
;
}
}
if (
{_z} && PyArray_NDIM({_z}) != 3) {
{
PyErr_Format(PyExc_NotImplementedError,
"rank(z) != 3. rank(z) is
%
%
d.",
PyArray_NDIM(
%(_z)
s
));
%(fail)
s
;
}
"rank(z) != 3. rank(z) is
%
d.",
PyArray_NDIM(
{_z}
));
{fail}
;
}
}
// allocate output
%(allocate)
s
{allocate}
// reallocate any noncontiguous arrays or arrays with invalid strides
%(contiguate)
s
{contiguate}
if ((PyArray_DESCR(
%(_x)
s
)->type_num != NPY_DOUBLE)
&& (PyArray_DESCR(
%(_x)
s
)->type_num != NPY_FLOAT))
{
PyErr_SetString(PyExc_NotImplementedError, "type(x) is not double or float");
%(fail)
s;
}
if ((PyArray_DESCR(
{_x}
)->type_num != NPY_DOUBLE)
&& (PyArray_DESCR(
{_x}
)->type_num != NPY_FLOAT))
{
{PyErr_SetString(PyExc_NotImplementedError, "type(x) is not double or float"); {fail};}
}
if ((PyArray_DESCR(
%(_y)
s
)->type_num != NPY_DOUBLE)
&& (PyArray_DESCR(
%(_y)
s
)->type_num != NPY_FLOAT))
{
PyErr_SetString(PyExc_NotImplementedError, "type(y) is not double or float");
%(fail)
s;
}
if ((PyArray_DESCR(
{_y}
)->type_num != NPY_DOUBLE)
&& (PyArray_DESCR(
{_y}
)->type_num != NPY_FLOAT))
{
{PyErr_SetString(PyExc_NotImplementedError, "type(y) is not double or float"); {fail};}
}
if ((PyArray_DESCR(
%(_z)
s
)->type_num != NPY_DOUBLE)
&& (PyArray_DESCR(
%(_z)
s
)->type_num != NPY_FLOAT))
{
PyErr_SetString(PyExc_NotImplementedError, "type(z) is not double or float");
%(fail)
s;
}
if ((PyArray_DESCR(
{_z}
)->type_num != NPY_DOUBLE)
&& (PyArray_DESCR(
{_z}
)->type_num != NPY_FLOAT))
{
{PyErr_SetString(PyExc_NotImplementedError, "type(z) is not double or float"); {fail};}
}
if ((PyArray_DESCR(
%(_x)
s)->type_num != PyArray_DESCR(
%(_y)
s
)->type_num)
||(PyArray_DESCR(
%(_x)
s)->type_num != PyArray_DESCR(
%(_z)
s
)->type_num))
{
PyErr_SetString(PyExc_NotImplementedError, "type(x), type(y), type(z) are not all the same");
%(fail)
s;
}
if ((PyArray_DESCR(
{_x})->type_num != PyArray_DESCR({_y}
)->type_num)
||(PyArray_DESCR(
{_x})->type_num != PyArray_DESCR({_z}
)->type_num))
{
{ PyErr_SetString(PyExc_NotImplementedError, "type(x), type(y), type(z) are not all the same"); {fail}; }
}
switch (type_num)
{
{
{
case NPY_FLOAT:
if (batch_gemm<float>(sgemm_, type_size,
%(_x)
s,
%(_y)
s,
%(_z)
s))
{
%(fail)
s
;
}
if (batch_gemm<float>(sgemm_, type_size,
{_x}, {_y}, {_z})) {
{
{fail}
;
}
}
break;
case NPY_DOUBLE:
if (batch_gemm<double>(dgemm_, type_size,
%(_x)
s,
%(_y)
s,
%(_z)
s))
{
%(fail)
s
;
}
if (batch_gemm<double>(dgemm_, type_size,
{_x}, {_y}, {_z})) {
{
{fail}
;
}
}
break;
}
"""
%
locals
(
)
}
}
"""
.
format
(
**
locals
()
)
def
c_code_cache_version
(
self
):
from
pytensor.tensor.blas_headers
import
blas_header_version
...
...
pytensor/tensor/blas_c.py
浏览文件 @
06c5acdf
差异被折叠。
点击展开。
pytensor/tensor/blas_headers.py
浏览文件 @
06c5acdf
...
...
@@ -1097,7 +1097,7 @@ def ____gemm_code(check_ab, a_init, b_init):
if (PyArray_NDIM(_y) != 2) goto _dot_execute_fallback;
if (PyArray_NDIM(_z) != 2) goto _dot_execute_fallback;
%(check_ab)
s
{check_ab}
if ((PyArray_DESCR(_x)->type_num != NPY_DOUBLE)
&& (PyArray_DESCR(_x)->type_num != NPY_FLOAT))
...
...
@@ -1117,16 +1117,16 @@ def ____gemm_code(check_ab, a_init, b_init):
if ((Nx[0] != Nz[0]) || (Nx[1] != Ny[0]) || (Ny[1] != Nz[1]))
{
{
{
error_string = "Input dimensions do not agree";
goto _dot_execute_fail;
}
if ((Sx[0] < 1) || (Sx[1] < 1) || (Sx[0]
%(mod)
s type_size) || (Sx[1]
%(mod)
s
type_size)
|| (Sy[0] < 1) || (Sy[1] < 1) || (Sy[0]
%(mod)
s type_size) || (Sy[1]
%(mod)
s
type_size)
|| (Sz[0] < 1) || (Sz[1] < 1) || (Sz[0]
%(mod)
s type_size) || (Sz[1]
%(mod)
s
type_size))
{
}
}
if ((Sx[0] < 1) || (Sx[1] < 1) || (Sx[0]
{mod} type_size) || (Sx[1] {mod}
type_size)
|| (Sy[0] < 1) || (Sy[1] < 1) || (Sy[0]
{mod} type_size) || (Sy[1] {mod}
type_size)
|| (Sz[0] < 1) || (Sz[1] < 1) || (Sz[0]
{mod} type_size) || (Sz[1] {mod}
type_size))
{
{
goto _dot_execute_fallback;
}
}
}
/*
encode the stride structure of _x,_y,_z into a single integer
...
...
@@ -1146,19 +1146,19 @@ def ____gemm_code(check_ab, a_init, b_init):
sz_1 = (Nz[1] > 1) ? Sz[1]/type_size : Nz[0];
switch (type_num)
{
{
{
case NPY_FLOAT:
{
{
{
#define REAL float
float a =
%(a_init)
s
;
float b =
%(b_init)
s
;
float a =
{a_init}
;
float b =
{b_init}
;
float* x = (float*)PyArray_DATA(_x);
float* y = (float*)PyArray_DATA(_y);
float* z = (float*)PyArray_DATA(_z);
switch(unit)
{
{
{
case 0x000: cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_0); break;
case 0x001: cblas_sgemm(CblasRowMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_0); break;
case 0x010: cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_0); break;
...
...
@@ -1168,21 +1168,21 @@ def ____gemm_code(check_ab, a_init, b_init):
case 0x110: cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_1); break;
case 0x111: cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_1); break;
default: goto _dot_execute_fallback;
};
}
}
;
#undef REAL
}
}
}
break;
case NPY_DOUBLE:
{
{
{
#define REAL double
double a =
%(a_init)
s
;
double b =
%(b_init)
s
;
double a =
{a_init}
;
double b =
{b_init}
;
double* x = (double*)PyArray_DATA(_x);
double* y = (double*)PyArray_DATA(_y);
double* z = (double*)PyArray_DATA(_z);
switch(unit)
{
{
{
case 0x000: cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_0); break;
case 0x001: cblas_dgemm(CblasRowMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_0); break;
case 0x010: cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_0); break;
...
...
@@ -1192,11 +1192,11 @@ def ____gemm_code(check_ab, a_init, b_init):
case 0x110: cblas_dgemm(CblasColMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_1); break;
case 0x111: cblas_dgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_1); break;
default: goto _dot_execute_fallback;
};
}
}
;
#undef REAL
}
}
}
break;
}
}
}
return 0; //success!
...
...
@@ -1212,4 +1212,4 @@ def ____gemm_code(check_ab, a_init, b_init):
return -1;
/* v 1 */
"""
%
locals
(
)
"""
.
format
(
**
locals
()
)
pytensor/tensor/elemwise.py
浏览文件 @
06c5acdf
...
...
@@ -929,12 +929,12 @@ class Elemwise(OpenMPOp):
# decrease the reference of whatever the output contained
# prior to this
alloc
+=
"""
if (
%(oname)
s)
{
Py_XDECREF(
%(oname)
s
);
}
%(oname)
s =
%(iname)
s
;
Py_XINCREF(
%(oname)
s
);
"""
%
locals
(
)
if (
{oname}) {
{
Py_XDECREF(
{oname}
);
}
}
{oname} = {iname}
;
Py_XINCREF(
{oname}
);
"""
.
format
(
**
locals
()
)
# We alias the scalar variables
defines
+=
f
"#define {oname}_i {iname}_i
\n
"
undefs
+=
f
"#undef {oname}_i
\n
"
...
...
@@ -958,12 +958,12 @@ class Elemwise(OpenMPOp):
dict
(
sub
,
fail
=
fail
),
)
code
=
"""
{
%(defines)
s
%(task_code)
s
%(undefs)
s
}
"""
%
locals
(
)
{
{
{defines}
{task_code}
{undefs}
}
}
"""
.
format
(
**
locals
()
)
loop_orders
=
orders
+
[
list
(
range
(
nnested
))]
*
len
(
real_onames
)
dtypes
=
idtypes
+
list
(
real_odtypes
)
...
...
@@ -995,27 +995,27 @@ class Elemwise(OpenMPOp):
if
index
!=
"x"
:
preloops
.
setdefault
(
j
,
""
)
preloops
[
j
]
+=
(
"
%
%
(lv
%(i)
s)s_iter = (
%(dtype)
s
*)"
"(PyArray_DATA(
%
%
(lv
%(i)
s)s));
\n
"
%
locals
(
)
"
%
(lv{i})s_iter = ({dtype}
*)"
"(PyArray_DATA(
%
(lv{i})s));
\n
"
.
format
(
**
locals
()
)
)
%
sub
break
else
:
# all broadcastable
preloops
.
setdefault
(
0
,
""
)
preloops
[
0
]
+=
(
"
%
%
(lv
%(i)
s)s_iter = (
%(dtype)
s
*)"
"(PyArray_DATA(
%
%
(lv
%(i)
s)s));
\n
"
%
locals
(
)
"
%
(lv{i})s_iter = ({dtype}
*)"
"(PyArray_DATA(
%
(lv{i})s));
\n
"
.
format
(
**
locals
()
)
)
%
sub
init_array
=
preloops
.
get
(
0
,
" "
)
loop
=
"""
{
%(defines)
s
%(init_array)
s
%(task_decl)
s
%(task_code)
s
%(undefs)
s
}
"""
%
locals
(
)
{
{
{defines}
{init_array}
{task_decl}
{task_code}
{undefs}
}
}
"""
.
format
(
**
locals
()
)
else
:
loop
=
cgen
.
make_loop
(
loop_orders
=
loop_orders
,
...
...
@@ -1076,24 +1076,24 @@ class Elemwise(OpenMPOp):
for
x
,
var
in
zip
(
inames
+
onames
,
inputs
+
node
.
outputs
):
if
not
all
(
s
==
1
for
s
in
var
.
type
.
shape
):
contig
+=
"""
dtype_
%(x)
s *
%(x)
s_ptr = (dtype_
%(x)
s*) PyArray_DATA(
%(x)
s
);
"""
%
locals
(
)
dtype_
{x} * {x}_ptr = (dtype_{x}*) PyArray_DATA({x}
);
"""
.
format
(
**
locals
()
)
index
+=
"""
dtype_
%(x)
s&
%(x)
s_i =
%(x)
s
_ptr[i];
"""
%
locals
(
)
dtype_
{x}& {x}_i = {x}
_ptr[i];
"""
.
format
(
**
locals
()
)
else
:
contig
+=
"""
dtype_
%(x)
s&
%(x)
s_i = ((dtype_
%(x)
s*) PyArray_DATA(
%(x)
s
))[0];
"""
%
locals
(
)
dtype_
{x}& {x}_i = ((dtype_{x}*) PyArray_DATA({x}
))[0];
"""
.
format
(
**
locals
()
)
if
self
.
openmp
:
contig
+=
f
"""#pragma omp parallel for if(n>={int(config.openmp_elemwise_minsize)})
"""
contig
+=
"""
for(int i=0; i<n; i++){
%(index)
s
%(task_code)
s
;
}
"""
%
locals
(
)
for(int i=0; i<n; i++){
{
{index}
{task_code}
;
}
}
"""
.
format
(
**
locals
()
)
if
contig
is
not
None
:
z
=
list
(
zip
(
inames
+
onames
,
inputs
+
node
.
outputs
))
all_broadcastable
=
all
(
s
==
1
for
s
in
var
.
type
.
shape
)
...
...
@@ -1112,12 +1112,12 @@ class Elemwise(OpenMPOp):
]
)
loop
=
"""
if((
%(cond1)
s) || (
%(cond2)
s))
{
%(contig)
s
}
else
{
%(loop)
s
}
"""
%
locals
(
)
if((
{cond1}) || ({cond2})){
{
{contig}
}
}else{
{
{loop}
}
}
"""
.
format
(
**
locals
()
)
return
decl
,
checks
,
alloc
,
loop
,
""
def
c_code
(
self
,
node
,
nodename
,
inames
,
onames
,
sub
):
...
...
pytensor/tensor/elemwise_cgen.py
浏览文件 @
06c5acdf
...
...
@@ -176,34 +176,34 @@ def make_alloc(loop_orders, dtype, sub, fortran="0"):
# contiguous dimensions, or the dimension with the smallest
# stride. Right now, it is allocated to be C_CONTIGUOUS.
return
"""
{
npy_intp dims[
%(nd)
s
];
//npy_intp* dims = (npy_intp*)malloc(
%(nd)
s
* sizeof(npy_intp));
%(init_dims)
s
if (!
%(olv)
s)
{
%(olv)
s = (PyArrayObject*)PyArray_EMPTY(
%(nd)
s
, dims,
%(type)
s
,
%(fortran)
s
);
}
else {
{
{
npy_intp dims[
{nd}
];
//npy_intp* dims = (npy_intp*)malloc(
{nd}
* sizeof(npy_intp));
{init_dims}
if (!
{olv}) {
{
{olv} = (PyArrayObject*)PyArray_EMPTY({nd}
, dims,
{type}
,
{fortran}
);
}
}
else {
{
PyArray_Dims new_dims;
new_dims.len =
%(nd)
s
;
new_dims.len =
{nd}
;
new_dims.ptr = dims;
PyObject* success = PyArray_Resize(
%(olv)
s
, &new_dims, 0, NPY_CORDER);
if (!success) {
PyObject* success = PyArray_Resize(
{olv}
, &new_dims, 0, NPY_CORDER);
if (!success) {
{
// If we can't resize the ndarray we have we can allocate a new one.
PyErr_Clear();
Py_XDECREF(
%(olv)
s
);
%(olv)
s = (PyArrayObject*)PyArray_EMPTY(
%(nd)
s, dims,
%(type)
s
, 0);
}
else
{
Py_XDECREF(
{olv}
);
{olv} = (PyArrayObject*)PyArray_EMPTY({nd}, dims, {type}
, 0);
}
} else {
{
Py_DECREF(success);
}
}
if (!
%(olv)
s)
{
%(fail)
s
}
}
"""
%
dict
(
locals
(),
**
sub
)
}
}
}
}
if (!
{olv}) {
{
{fail}
}
}
}
}
"""
.
format
(
**
dict
(
locals
(),
**
sub
)
)
def
make_loop
(
loop_orders
,
dtypes
,
loop_tasks
,
sub
,
openmp
=
None
):
...
...
pytensor/tensor/extra_ops.py
浏览文件 @
06c5acdf
...
...
@@ -69,24 +69,24 @@ class CpuContiguous(COp):
(
x
,)
=
inames
(
y
,)
=
onames
code
=
"""
if (!PyArray_CHKFLAGS(
%(x)
s, NPY_ARRAY_C_CONTIGUOUS))
{
if (!PyArray_CHKFLAGS(
{x}, NPY_ARRAY_C_CONTIGUOUS)){
{
// check to see if output is contiguous first
if (
%(y)
s
!= NULL &&
PyArray_CompareLists(PyArray_DIMS(
%(y)
s), PyArray_DIMS(
%(x)
s), PyArray_NDIM(
%(x)
s
)) &&
PyArray_CHKFLAGS(
%(y)
s, NPY_ARRAY_C_CONTIGUOUS))
{
PyArray_CopyInto(
%(y)
s,
%(x)
s
);
}
else{
Py_XDECREF(
%(y)
s
);
%(y)
s = PyArray_GETCONTIGUOUS(
%(x)
s
);
}
}
else{
Py_XINCREF(
%(x)
s
);
Py_XDECREF(
%(y)
s
);
%(y)
s =
%(x)
s
;
}
"""
%
locals
(
)
if (
{y}
!= NULL &&
PyArray_CompareLists(PyArray_DIMS(
{y}), PyArray_DIMS({x}), PyArray_NDIM({x}
)) &&
PyArray_CHKFLAGS(
{y}, NPY_ARRAY_C_CONTIGUOUS)){
{
PyArray_CopyInto(
{y}, {x}
);
}
}
else{
{
Py_XDECREF(
{y}
);
{y} = PyArray_GETCONTIGUOUS({x}
);
}
}
}
}
else{
{
Py_XINCREF(
{x}
);
Py_XDECREF(
{y}
);
{y} = {x}
;
}
}
"""
.
format
(
**
locals
()
)
return
code
def
c_code_cache_version
(
self
):
...
...
@@ -162,12 +162,12 @@ class SearchsortedOp(COp):
side
=
sub
[
"params"
]
fail
=
sub
[
"fail"
]
return
"""
PyObject* tmp_
%(name)
s
= PyUnicode_FromString("right");
if (tmp_
%(name)
s
== NULL)
%(fail)
s
;
right_
%(name)
s = PyUnicode_Compare(
%(side)
s, tmp_
%(name)
s
);
Py_DECREF(tmp_
%(name)
s
);
"""
%
locals
(
)
PyObject* tmp_
{name}
= PyUnicode_FromString("right");
if (tmp_
{name}
== NULL)
{fail}
;
right_
{name} = PyUnicode_Compare({side}, tmp_{name}
);
Py_DECREF(tmp_
{name}
);
"""
.
format
(
**
locals
()
)
def
c_code
(
self
,
node
,
name
,
inames
,
onames
,
sub
):
sorter
=
None
...
...
@@ -181,17 +181,17 @@ class SearchsortedOp(COp):
fail
=
sub
[
"fail"
]
return
"""
Py_XDECREF(
%(z)
s
);
%(z)
s = (PyArrayObject*) PyArray_SearchSorted(
%(x)
s, (PyObject*)
%(v)
s
,
right_
%(name)
s ? NPY_SEARCHLEFT : NPY_SEARCHRIGHT, (PyObject*)
%(sorter)
s
);
if (!
%(z)
s
)
%(fail)
s
;
if (PyArray_TYPE(
%(z)
s) != NPY_INT64)
{
PyObject * tmp = PyArray_Cast(
%(z)
s
, NPY_INT64);
Py_XDECREF(
%(z)
s
);
%(z)
s
= (PyArrayObject*) tmp;
}
"""
%
locals
(
)
Py_XDECREF(
{z}
);
{z} = (PyArrayObject*) PyArray_SearchSorted({x}, (PyObject*) {v}
,
right_
{name} ? NPY_SEARCHLEFT : NPY_SEARCHRIGHT, (PyObject*) {sorter}
);
if (!
{z}
)
{fail}
;
if (PyArray_TYPE(
{z}) != NPY_INT64){
{
PyObject * tmp = PyArray_Cast(
{z}
, NPY_INT64);
Py_XDECREF(
{z}
);
{z}
= (PyArrayObject*) tmp;
}
}
"""
.
format
(
**
locals
()
)
def
c_code_cache_version
(
self
):
return
(
2
,)
...
...
@@ -351,43 +351,43 @@ class CumOp(COp):
params
=
sub
[
"params"
]
code
=
"""
int axis =
%(params)
s
->c_axis;
if (axis == 0 && PyArray_NDIM(
%(x)
s
) == 1)
int axis =
{params}
->c_axis;
if (axis == 0 && PyArray_NDIM(
{x}
) == 1)
axis = NPY_MAXDIMS;
npy_intp shape[1] = {
PyArray_SIZE(
%(x)
s)
};
if(axis == NPY_MAXDIMS && !(
%(z)
s && PyArray_DIMS(
%(z)
s
)[0] == shape[0]))
{
Py_XDECREF(
%(z)
s
);
%(z)
s = (PyArrayObject*) PyArray_SimpleNew(1, shape, PyArray_TYPE((PyArrayObject*) py_
%(x)
s
));
}
else if(axis != NPY_MAXDIMS && !(
%(z)
s && PyArray_CompareLists(PyArray_DIMS(
%(z)
s), PyArray_DIMS(
%(x)
s), PyArray_NDIM(
%(x)
s
))))
{
Py_XDECREF(
%(z)
s
);
%(z)
s = (PyArrayObject*) PyArray_SimpleNew(PyArray_NDIM(
%(x)
s), PyArray_DIMS(
%(x)
s), PyArray_TYPE(
%(x)
s
));
}
if (!
%(z)
s
)
%(fail)
s
;
{
npy_intp shape[1] = {
{ PyArray_SIZE({x}) }
};
if(axis == NPY_MAXDIMS && !(
{z} && PyArray_DIMS({z}
)[0] == shape[0]))
{
{
Py_XDECREF(
{z}
);
{z} = (PyArrayObject*) PyArray_SimpleNew(1, shape, PyArray_TYPE((PyArrayObject*) py_{x}
));
}
}
else if(axis != NPY_MAXDIMS && !(
{z} && PyArray_CompareLists(PyArray_DIMS({z}), PyArray_DIMS({x}), PyArray_NDIM({x}
))))
{
{
Py_XDECREF(
{z}
);
{z} = (PyArrayObject*) PyArray_SimpleNew(PyArray_NDIM({x}), PyArray_DIMS({x}), PyArray_TYPE({x}
));
}
}
if (!
{z}
)
{fail}
;
{
{
PyObject * t = NULL;
if(
%(params)
s
->mode == MODE_ADD)
if(
{params}
->mode == MODE_ADD)
t = PyArray_CumSum(
%(x)
s
, axis,
PyArray_TYPE(
%(x)
s),
%(z)
s
);
else if(
%(params)
s
->mode == MODE_MUL)
{x}
, axis,
PyArray_TYPE(
{x}), {z}
);
else if(
{params}
->mode == MODE_MUL)
t = PyArray_CumProd(
%(x)
s
, axis,
PyArray_TYPE(
%(x)
s),
%(z)
s
);
{x}
, axis,
PyArray_TYPE(
{x}), {z}
);
if (!t){
%(fail)
s
;
}
if (!t){
{
{fail}
;
}
}
// Because PyArray_CumSum/CumProd returns a newly created reference on t.
Py_XDECREF(t);
}
"""
%
locals
(
)
}
}
"""
.
format
(
**
locals
()
)
return
code
...
...
pytensor/tensor/math.py
浏览文件 @
06c5acdf
...
...
@@ -420,13 +420,13 @@ class Argmax(COp):
raise
NotImplementedError
()
# params is only used here for now
axis_code
=
"""
axis =
%(params)
s
->c_axis;
if(axis > PyArray_NDIM(
%(x)
s)-1 || axis < -PyArray_NDIM(
%(x)
s))
{
axis =
{params}
->c_axis;
if(axis > PyArray_NDIM(
{x})-1 || axis < -PyArray_NDIM({x})){
{
PyErr_SetString(PyExc_ValueError,
"Argmax, bad axis argument");
%(fail)
s
}
"""
%
locals
(
)
{fail}
}
}
"""
.
format
(
**
locals
()
)
ret
=
"""
int axis;
...
...
pytensor/tensor/subtensor.py
浏览文件 @
06c5acdf
差异被折叠。
点击展开。
pytensor/tensor/type.py
浏览文件 @
06c5acdf
...
...
@@ -476,47 +476,47 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
def
c_declare
(
self
,
name
,
sub
,
check_input
=
True
):
if
check_input
:
check
=
"""
typedef
%(dtype)
s dtype_
%(name)
s
;
"""
%
dict
(
sub
,
name
=
name
,
dtype
=
self
.
dtype_specs
()[
1
]
)
typedef
{dtype} dtype_{name}
;
"""
.
format
(
**
dict
(
sub
,
name
=
name
,
dtype
=
self
.
dtype_specs
()[
1
])
)
else
:
check
=
""
declaration
=
"""
PyArrayObject*
%(name)
s
;
"""
%
dict
(
sub
,
name
=
name
,
dtype
=
self
.
dtype_specs
()[
1
]
)
PyArrayObject*
{name}
;
"""
.
format
(
**
dict
(
sub
,
name
=
name
,
dtype
=
self
.
dtype_specs
()[
1
])
)
return
declaration
+
check
def
c_init
(
self
,
name
,
sub
):
return
"""
%(name)
s
= NULL;
"""
%
dict
(
sub
,
name
=
name
,
type_num
=
self
.
dtype_specs
()[
2
]
)
{name}
= NULL;
"""
.
format
(
**
dict
(
sub
,
name
=
name
,
type_num
=
self
.
dtype_specs
()[
2
])
)
def
c_extract
(
self
,
name
,
sub
,
check_input
=
True
,
**
kwargs
):
if
check_input
:
check
=
"""
%(name)
s
= NULL;
if (py_
%(name)
s == Py_None)
{
// We can either fail here or set
%(name)
s
to NULL and rely on Ops
{name}
= NULL;
if (py_
{name} == Py_None) {
{
// We can either fail here or set
{name}
to NULL and rely on Ops
// using tensors to handle the NULL case, but if they fail to do so
// they'll end up with nasty segfaults, so this is public service.
PyErr_SetString(PyExc_ValueError, "expected an ndarray, not None");
%(fail)
s
}
if (!PyArray_Check(py_
%(name)
s))
{
{fail}
}
}
if (!PyArray_Check(py_
{name})) {
{
PyErr_SetString(PyExc_ValueError, "expected an ndarray");
%(fail)
s
}
// We expect
%(type_num)
s
if (!PyArray_ISALIGNED((PyArrayObject*) py_
%(name)
s))
{
PyArrayObject * tmp = (PyArrayObject*) py_
%(name)
s
;
{fail}
}
}
// We expect
{type_num}
if (!PyArray_ISALIGNED((PyArrayObject*) py_
{name})) {
{
PyArrayObject * tmp = (PyArrayObject*) py_
{name}
;
PyErr_Format(PyExc_NotImplementedError,
"expected an aligned array of type
%
%
ld "
"(
%(type_num)
s), got non-aligned array of type
%
%
ld"
" with
%
%
ld dimensions, with 3 last dims "
"
%
%
ld,
%%
ld,
%
%
ld"
" and 3 last strides
%
%
ld
%%
ld,
%
%
ld.",
(long int)
%(type_num)
s
,
(long int) PyArray_TYPE((PyArrayObject*) py_
%(name)
s
),
"expected an aligned array of type
%
ld "
"(
{type_num}), got non-aligned array of type
%
ld"
" with
%
ld dimensions, with 3 last dims "
"
%
ld,
%
ld,
%
ld"
" and 3 last strides
%
ld
%
ld,
%
ld.",
(long int)
{type_num}
,
(long int) PyArray_TYPE((PyArrayObject*) py_
{name}
),
(long int) PyArray_NDIM(tmp),
(long int) (PyArray_NDIM(tmp) >= 3 ?
PyArray_DIMS(tmp)[PyArray_NDIM(tmp)-3] : -1),
...
...
@@ -531,74 +531,73 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
(long int) (PyArray_NDIM(tmp) >= 1 ?
PyArray_STRIDES(tmp)[PyArray_NDIM(tmp)-1] : -1)
);
%(fail)
s
}
{fail}
}
}
// This is a TypeError to be consistent with DEBUG_MODE
// Note: DEBUG_MODE also tells the name of the container
if (PyArray_TYPE((PyArrayObject*) py_
%(name)
s) !=
%(type_num)
s)
{
if (PyArray_TYPE((PyArrayObject*) py_
{name}) != {type_num}) {
{
PyErr_Format(PyExc_TypeError,
"expected type_num
%
%
d (
%(type_num)
s) got
%
%
d",
%(type_num)
s, PyArray_TYPE((PyArrayObject*) py_
%(name)
s
));
%(fail)
s
}
"""
%
dict
(
sub
,
name
=
name
,
type_num
=
self
.
dtype_specs
()[
2
]
)
"expected type_num
%
d ({type_num}) got
%
d",
{type_num}, PyArray_TYPE((PyArrayObject*) py_{name}
));
{fail}
}
}
"""
.
format
(
**
dict
(
sub
,
name
=
name
,
type_num
=
self
.
dtype_specs
()[
2
])
)
else
:
check
=
""
return
(
check
+
"""
%(name)
s = (PyArrayObject*)(py_
%(name)
s);
Py_XINCREF(
%(name)
s);
"""
%
dict
(
sub
,
name
=
name
,
type_num
=
self
.
dtype_specs
()[
2
])
{name} = (PyArrayObject*)(py_{name});
Py_XINCREF({name});
"""
.
format
(
**
dict
(
sub
,
name
=
name
,
type_num
=
self
.
dtype_specs
()[
2
]))
)
def
c_cleanup
(
self
,
name
,
sub
):
return
"""
if (
%(name)
s)
{
Py_XDECREF(
%(name)
s
);
}
"""
%
locals
(
)
if (
{name}) {
{
Py_XDECREF(
{name}
);
}
}
"""
.
format
(
**
locals
()
)
def
c_sync
(
self
,
name
,
sub
):
fail
=
sub
[
"fail"
]
type_num
=
self
.
dtype_specs
()[
2
]
return
"""
{
Py_XDECREF(py_
%(name)
s);
}
if (!
%(name)
s)
{
{
{Py_XDECREF(py_{name});}
}
if (!
{name}) {
{
Py_INCREF(Py_None);
py_
%(name)
s
= Py_None;
}
else if ((void*)py_
%(name)
s != (void*)
%(name)
s)
{
py_
%(name)
s = (PyObject*)
%(name)
s
;
}
py_
{name}
= Py_None;
}
}
else if ((void*)py_
{name} != (void*){name}) {
{
py_
{name} = (PyObject*){name}
;
}
}
{
Py_XINCREF(py_
%(name)
s);
}
{
{Py_XINCREF(py_{name});}
}
if (
%(name)
s && !PyArray_ISALIGNED((PyArrayObject*) py_
%(name)
s))
{
if (
{name} && !PyArray_ISALIGNED((PyArrayObject*) py_{name})) {
{
PyErr_Format(PyExc_NotImplementedError,
"c_sync: expected an aligned array, got non-aligned array of type
%
%
ld"
" with
%
%
ld dimensions, with 3 last dims "
"
%
%
ld,
%%
ld,
%
%
ld"
" and 3 last strides
%
%
ld
%%
ld,
%
%
ld.",
(long int) PyArray_TYPE((PyArrayObject*) py_
%(name)
s
),
(long int) PyArray_NDIM(
%(name)
s
),
(long int) (PyArray_NDIM(
%(name)
s
) >= 3 ?
PyArray_DIMS(
%(name)
s)[PyArray_NDIM(
%(name)
s
)-3] : -1),
(long int) (PyArray_NDIM(
%(name)
s
) >= 2 ?
PyArray_DIMS(
%(name)
s)[PyArray_NDIM(
%(name)
s
)-2] : -1),
(long int) (PyArray_NDIM(
%(name)
s
) >= 1 ?
PyArray_DIMS(
%(name)
s)[PyArray_NDIM(
%(name)
s
)-1] : -1),
(long int) (PyArray_NDIM(
%(name)
s
) >= 3 ?
PyArray_STRIDES(
%(name)
s)[PyArray_NDIM(
%(name)
s
)-3] : -1),
(long int) (PyArray_NDIM(
%(name)
s
) >= 2 ?
PyArray_STRIDES(
%(name)
s)[PyArray_NDIM(
%(name)
s
)-2] : -1),
(long int) (PyArray_NDIM(
%(name)
s
) >= 1 ?
PyArray_STRIDES(
%(name)
s)[PyArray_NDIM(
%(name)
s
)-1] : -1)
"c_sync: expected an aligned array, got non-aligned array of type
%
ld"
" with
%
ld dimensions, with 3 last dims "
"
%
ld,
%
ld,
%
ld"
" and 3 last strides
%
ld
%
ld,
%
ld.",
(long int) PyArray_TYPE((PyArrayObject*) py_
{name}
),
(long int) PyArray_NDIM(
{name}
),
(long int) (PyArray_NDIM(
{name}
) >= 3 ?
PyArray_DIMS(
{name})[PyArray_NDIM({name}
)-3] : -1),
(long int) (PyArray_NDIM(
{name}
) >= 2 ?
PyArray_DIMS(
{name})[PyArray_NDIM({name}
)-2] : -1),
(long int) (PyArray_NDIM(
{name}
) >= 1 ?
PyArray_DIMS(
{name})[PyArray_NDIM({name}
)-1] : -1),
(long int) (PyArray_NDIM(
{name}
) >= 3 ?
PyArray_STRIDES(
{name})[PyArray_NDIM({name}
)-3] : -1),
(long int) (PyArray_NDIM(
{name}
) >= 2 ?
PyArray_STRIDES(
{name})[PyArray_NDIM({name}
)-2] : -1),
(long int) (PyArray_NDIM(
{name}
) >= 1 ?
PyArray_STRIDES(
{name})[PyArray_NDIM({name}
)-1] : -1)
);
%(fail)
s
}
"""
%
locals
(
)
{fail}
}
}
"""
.
format
(
**
locals
()
)
def
c_headers
(
self
,
**
kwargs
):
return
ps
.
get_scalar_type
(
self
.
dtype
)
.
c_headers
(
**
kwargs
)
...
...
pytensor/typed_list/basic.py
浏览文件 @
06c5acdf
...
...
@@ -103,12 +103,12 @@ class GetItem(COp):
output_name
=
out
[
0
]
fail
=
sub
[
"fail"
]
return
"""
%(output_name)
s = (typeof
%(output_name)
s) PyList_GetItem( (PyObject*)
%(x_name)
s, *((npy_int64 *) PyArray_DATA(
%(index)
s
)));
if(
%(output_name)
s == NULL)
{
%(fail)
s
}
Py_INCREF(
%(output_name)
s
);
"""
%
locals
(
)
{output_name} = (typeof {output_name}) PyList_GetItem( (PyObject*) {x_name}, *((npy_int64 *) PyArray_DATA({index}
)));
if(
{output_name} == NULL){
{
{fail}
}
}
Py_INCREF(
{output_name}
);
"""
.
format
(
**
locals
()
)
def
c_code_cache_version
(
self
):
return
(
1
,)
...
...
@@ -170,8 +170,8 @@ class Append(COp):
fail
=
sub
[
"fail"
]
if
not
self
.
inplace
:
init
=
"""
%(output_name)
s = (PyListObject*) PyList_GetSlice((PyObject*)
%(x_name)
s, 0, PyList_GET_SIZE((PyObject*)
%(x_name)
s
)) ;
"""
%
locals
(
)
{output_name} = (PyListObject*) PyList_GetSlice((PyObject*) {x_name}, 0, PyList_GET_SIZE((PyObject*) {x_name}
)) ;
"""
.
format
(
**
locals
()
)
else
:
init
=
f
"""
{output_name} = {x_name};
...
...
@@ -179,15 +179,14 @@ class Append(COp):
return
(
init
+
"""
if(
%(output_name)
s==NULL){
%(fail)
s
};
if(PyList_Append( (PyObject*)
%(output_name)
s,(PyObject*)
%(toAppend)
s)){
%(fail)
s
};
Py_INCREF(
%(output_name)
s);
"""
%
locals
()
if({output_name}==NULL){{
{fail}
}};
if(PyList_Append( (PyObject*) {output_name},(PyObject*) {toAppend})){{
{fail}
}};
Py_INCREF({output_name});
"""
.
format
(
**
locals
())
)
def
c_code_cache_version
(
self
):
...
...
@@ -252,8 +251,8 @@ class Extend(COp):
fail
=
sub
[
"fail"
]
if
not
self
.
inplace
:
init
=
"""
%(output_name)
s = (PyListObject*) PyList_GetSlice((PyObject*)
%(x_name)
s, 0, PyList_GET_SIZE((PyObject*)
%(x_name)
s
)) ;
"""
%
locals
(
)
{output_name} = (PyListObject*) PyList_GetSlice((PyObject*) {x_name}, 0, PyList_GET_SIZE((PyObject*) {x_name}
)) ;
"""
.
format
(
**
locals
()
)
else
:
init
=
f
"""
{output_name} = {x_name};
...
...
@@ -262,18 +261,17 @@ class Extend(COp):
init
+
"""
int i =0;
int length = PyList_GET_SIZE((PyObject*)
%(toAppend)
s);
if(
%(output_name)
s==NULL){
%(fail)
s
};
for(i; i < length; i++){
if(PyList_Append( (PyObject*)
%(output_name)
s,(PyObject*) PyList_GetItem((PyObject*)
%(toAppend)
s,i))==-1){
%(fail)
s
};
}
Py_INCREF(
%(output_name)
s);
"""
%
locals
()
int length = PyList_GET_SIZE((PyObject*) {toAppend});
if({output_name}==NULL){{
{fail}
}};
for(i; i < length; i++){{
if(PyList_Append( (PyObject*) {output_name},(PyObject*) PyList_GetItem((PyObject*) {toAppend},i))==-1){{
{fail}
}};
}}
Py_INCREF({output_name});
"""
.
format
(
**
locals
())
)
def
c_code_cache_version_
(
self
):
...
...
@@ -341,8 +339,8 @@ class Insert(COp):
fail
=
sub
[
"fail"
]
if
not
self
.
inplace
:
init
=
"""
%(output_name)
s = (PyListObject*) PyList_GetSlice((PyObject*)
%(x_name)
s, 0, PyList_GET_SIZE((PyObject*)
%(x_name)
s
)) ;
"""
%
locals
(
)
{output_name} = (PyListObject*) PyList_GetSlice((PyObject*) {x_name}, 0, PyList_GET_SIZE((PyObject*) {x_name}
)) ;
"""
.
format
(
**
locals
()
)
else
:
init
=
f
"""
{output_name} = {x_name};
...
...
@@ -350,15 +348,14 @@ class Insert(COp):
return
(
init
+
"""
if(
%(output_name)
s==NULL){
%(fail)
s
};
if(PyList_Insert((PyObject*)
%(output_name)
s, *((npy_int64 *) PyArray_DATA(
%(index)
s)), (PyObject*)
%(toInsert)
s)==-1){
%(fail)
s
};
Py_INCREF(
%(output_name)
s);
"""
%
locals
()
if({output_name}==NULL){{
{fail}
}};
if(PyList_Insert((PyObject*) {output_name}, *((npy_int64 *) PyArray_DATA({index})), (PyObject*) {toInsert})==-1){{
{fail}
}};
Py_INCREF({output_name});
"""
.
format
(
**
locals
())
)
def
c_code_cache_version
(
self
):
...
...
@@ -470,8 +467,8 @@ class Reverse(COp):
fail
=
sub
[
"fail"
]
if
not
self
.
inplace
:
init
=
"""
%(output_name)
s = (PyListObject*) PyList_GetSlice((PyObject*)
%(x_name)
s, 0, PyList_GET_SIZE((PyObject*)
%(x_name)
s
)) ;
"""
%
locals
(
)
{output_name} = (PyListObject*) PyList_GetSlice((PyObject*) {x_name}, 0, PyList_GET_SIZE((PyObject*) {x_name}
)) ;
"""
.
format
(
**
locals
()
)
else
:
init
=
f
"""
{output_name} = {x_name};
...
...
@@ -479,15 +476,14 @@ class Reverse(COp):
return
(
init
+
"""
if(
%(output_name)
s==NULL){
%(fail)
s
};
if(PyList_Reverse((PyObject*)
%(output_name)
s)==-1){
%(fail)
s
};
Py_INCREF(
%(output_name)
s);
"""
%
locals
()
if({output_name}==NULL){{
{fail}
}};
if(PyList_Reverse((PyObject*) {output_name})==-1){{
{fail}
}};
Py_INCREF({output_name});
"""
.
format
(
**
locals
())
)
def
c_code_cache_version
(
self
):
...
...
@@ -602,11 +598,11 @@ class Length(COp):
output_name
=
out
[
0
]
fail
=
sub
[
"fail"
]
return
"""
if(!
%(output_name)
s
)
%(output_name)
s
=(PyArrayObject*)PyArray_EMPTY(0, NULL, NPY_INT64, 0);
((npy_int64*)PyArray_DATA(
%(output_name)
s))[0]=PyList_Size((PyObject*)
%(x_name)
s
);
Py_INCREF(
%(output_name)
s
);
"""
%
locals
(
)
if(!
{output_name}
)
{output_name}
=(PyArrayObject*)PyArray_EMPTY(0, NULL, NPY_INT64, 0);
((npy_int64*)PyArray_DATA(
{output_name}))[0]=PyList_Size((PyObject*){x_name}
);
Py_INCREF(
{output_name}
);
"""
.
format
(
**
locals
()
)
def
c_code_cache_version
(
self
):
return
(
1
,)
...
...
pytensor/typed_list/type.py
浏览文件 @
06c5acdf
...
...
@@ -111,26 +111,25 @@ class TypedListType(CType):
def
c_extract
(
self
,
name
,
sub
,
check_input
=
True
,
**
kwargs
):
if
check_input
:
pre
=
"""
if (!PyList_Check(py_
%(name)
s))
{
if (!PyList_Check(py_
{name})) {
{
PyErr_SetString(PyExc_TypeError, "expected a list");
%(fail)
s
}
"""
%
dict
(
name
=
name
,
fail
=
sub
[
"fail"
]
)
{fail}
}
}"""
.
format
(
**
dict
(
name
=
name
,
fail
=
sub
[
"fail"
])
)
else
:
pre
=
""
return
(
pre
+
"""
%(name)
s = (PyListObject*) (py_
%(name)
s);
"""
%
dict
(
name
=
name
,
fail
=
sub
[
"fail"
])
{name} = (PyListObject*) (py_{name});
"""
.
format
(
**
dict
(
name
=
name
,
fail
=
sub
[
"fail"
]))
)
def
c_sync
(
self
,
name
,
sub
):
return
"""
Py_XDECREF(py_
%(name)
s
);
py_
%(name)
s = (PyObject*)(
%(name)
s
);
Py_INCREF(py_
%(name)
s
);
"""
%
dict
(
name
=
name
)
Py_XDECREF(py_
{name}
);
py_
{name} = (PyObject*)({name}
);
Py_INCREF(py_
{name}
);
"""
.
format
(
**
dict
(
name
=
name
)
)
def
c_cleanup
(
self
,
name
,
sub
):
return
""
...
...
tests/compile/test_debugmode.py
浏览文件 @
06c5acdf
差异被折叠。
点击展开。
tests/link/c/test_basic.py
浏览文件 @
06c5acdf
差异被折叠。
点击展开。
tests/link/c/test_op.py
浏览文件 @
06c5acdf
差异被折叠。
点击展开。
tests/link/c/test_params_type.py
浏览文件 @
06c5acdf
差异被折叠。
点击展开。
tests/link/c/test_type.py
浏览文件 @
06c5acdf
差异被折叠。
点击展开。
tests/tensor/conv/c_conv3d_corr3d_ref.py
浏览文件 @
06c5acdf
差异被折叠。
点击展开。
tests/tensor/conv/c_conv_corr_ref.py
浏览文件 @
06c5acdf
差异被折叠。
点击展开。
tests/tensor/rewriting/test_elemwise.py
浏览文件 @
06c5acdf
差异被折叠。
点击展开。
tests/tensor/utils.py
浏览文件 @
06c5acdf
差异被折叠。
点击展开。
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论