Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
55b2f4fa
提交
55b2f4fa
authored
7月 07, 2024
作者:
Virgile Andreani
提交者:
Virgile Andreani
7月 12, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Replace str.format with f-strings
上级
e2f9cb8e
隐藏空白字符变更
内嵌
并排
正在显示
39 个修改的文件
包含
361 行增加
和
462 行删除
+361
-462
builders.py
pytensor/compile/builders.py
+1
-1
types.py
pytensor/compile/function/types.py
+1
-5
basic.py
pytensor/graph/rewriting/basic.py
+5
-8
utils.py
pytensor/graph/utils.py
+3
-4
basic.py
pytensor/link/c/basic.py
+18
-22
cmodule.py
pytensor/link/c/cmodule.py
+2
-2
interface.py
pytensor/link/c/interface.py
+4
-8
op.py
pytensor/link/c/op.py
+4
-13
params_type.py
pytensor/link/c/params_type.py
+26
-42
type.py
pytensor/link/c/type.py
+35
-47
slinalg.py
pytensor/link/numba/dispatch/slinalg.py
+4
-5
utils.py
pytensor/link/utils.py
+4
-8
printing.py
pytensor/printing.py
+2
-4
basic.py
pytensor/scalar/basic.py
+15
-17
math.py
pytensor/scalar/math.py
+10
-10
utils.py
pytensor/scan/utils.py
+2
-1
basic.py
pytensor/sparse/basic.py
+6
-4
rewriting.py
pytensor/sparse/rewriting.py
+29
-20
basic.py
pytensor/tensor/basic.py
+18
-24
blas.py
pytensor/tensor/blas.py
+11
-18
blas_c.py
pytensor/tensor/blas_c.py
+2
-2
blas_headers.py
pytensor/tensor/blas_headers.py
+2
-2
elemwise.py
pytensor/tensor/elemwise.py
+19
-24
elemwise_cgen.py
pytensor/tensor/elemwise_cgen.py
+4
-2
extra_ops.py
pytensor/tensor/extra_ops.py
+8
-9
math.py
pytensor/tensor/math.py
+4
-3
subtensor.py
pytensor/tensor/subtensor.py
+22
-30
type.py
pytensor/tensor/type.py
+17
-15
basic.py
pytensor/typed_list/basic.py
+20
-21
type.py
pytensor/typed_list/type.py
+7
-6
test_debugmode.py
tests/compile/test_debugmode.py
+6
-4
test_basic.py
tests/link/c/test_basic.py
+3
-2
test_cmodule.py
tests/link/c/test_cmodule.py
+2
-4
test_op.py
tests/link/c/test_op.py
+3
-3
test_params_type.py
tests/link/c/test_params_type.py
+10
-16
test_type.py
tests/link/c/test_type.py
+22
-29
c_conv3d_corr3d_ref.py
tests/tensor/conv/c_conv3d_corr3d_ref.py
+4
-13
c_conv_corr_ref.py
tests/tensor/conv/c_conv_corr_ref.py
+4
-12
test_elemwise.py
tests/tensor/rewriting/test_elemwise.py
+2
-2
没有找到文件。
pytensor/compile/builders.py
浏览文件 @
55b2f4fa
...
@@ -434,7 +434,7 @@ class OpFromGraph(Op, HasInnerGraph):
...
@@ -434,7 +434,7 @@ class OpFromGraph(Op, HasInnerGraph):
def
__str__
(
self
):
def
__str__
(
self
):
name
=
self
.
__class__
.
__name__
if
self
.
name
is
None
else
self
.
name
name
=
self
.
__class__
.
__name__
if
self
.
name
is
None
else
self
.
name
is_inline
=
self
.
is_inline
is_inline
=
self
.
is_inline
return
"{name}{{inline={is_inline}}}"
.
format
(
**
locals
())
return
f
"{name}{{inline={is_inline}}}"
def
_combine_list_overrides
(
self
,
default_outs
,
custom_outs
,
callable_args
):
def
_combine_list_overrides
(
self
,
default_outs
,
custom_outs
,
callable_args
):
"""Combines default and custom overrides into a single list of outputs."""
"""Combines default and custom overrides into a single list of outputs."""
...
...
pytensor/compile/function/types.py
浏览文件 @
55b2f4fa
...
@@ -1890,11 +1890,7 @@ def get_info_on_inputs(named_inputs, n_unnamed_inputs):
...
@@ -1890,11 +1890,7 @@ def get_info_on_inputs(named_inputs, n_unnamed_inputs):
)
)
else
:
else
:
if
n_unnamed_inputs
==
0
:
if
n_unnamed_inputs
==
0
:
msg
=
"The function has {} named input{} ({})."
.
format
(
msg
=
f
"The function has {n_named_inputs} named input{get_plural(n_named_inputs)} ({', '.join(named_inputs)})."
n_named_inputs
,
get_plural
(
n_named_inputs
),
", "
.
join
(
named_inputs
),
)
else
:
else
:
msg
=
(
msg
=
(
f
"The function has {n_named_inputs} named input{get_plural(n_named_inputs)} ({', '.join(named_inputs)}), and {n_unnamed_inputs} unnamed "
f
"The function has {n_named_inputs} named input{get_plural(n_named_inputs)} ({', '.join(named_inputs)}), and {n_unnamed_inputs} unnamed "
...
...
pytensor/graph/rewriting/basic.py
浏览文件 @
55b2f4fa
...
@@ -1664,15 +1664,12 @@ class PatternNodeRewriter(NodeRewriter):
...
@@ -1664,15 +1664,12 @@ class PatternNodeRewriter(NodeRewriter):
def
pattern_to_str
(
pattern
):
def
pattern_to_str
(
pattern
):
if
isinstance
(
pattern
,
list
|
tuple
):
if
isinstance
(
pattern
,
list
|
tuple
):
return
"{}({})"
.
format
(
args
=
", "
.
join
(
pattern_to_str
(
p
)
for
p
in
pattern
[
1
:])
str
(
pattern
[
0
]),
return
f
"{pattern[0]!s}({args})"
", "
.
join
(
pattern_to_str
(
p
)
for
p
in
pattern
[
1
:]),
)
elif
isinstance
(
pattern
,
dict
):
elif
isinstance
(
pattern
,
dict
):
return
"{} subject to {}"
.
format
(
a
=
pattern_to_str
(
pattern
[
"pattern"
])
pattern_to_str
(
pattern
[
"pattern"
]),
b
=
pattern
.
get
(
"constraint"
,
"no conditions"
)
str
(
pattern
.
get
(
"constraint"
,
"no conditions"
)),
return
f
"{a} subject to {b}"
)
else
:
else
:
return
str
(
pattern
)
return
str
(
pattern
)
...
...
pytensor/graph/utils.py
浏览文件 @
55b2f4fa
...
@@ -245,10 +245,9 @@ class MetaType(ABCMeta):
...
@@ -245,10 +245,9 @@ class MetaType(ABCMeta):
else
:
else
:
def
__str__
(
self
):
def
__str__
(
self
):
return
"{}{{{}}}"
.
format
(
classname
=
self
.
__class__
.
__name__
self
.
__class__
.
__name__
,
args
=
", "
.
join
(
f
"{p}={getattr(self, p)!r}"
for
p
in
props
)
", "
.
join
(
f
"{p}={getattr(self, p)!r}"
for
p
in
props
),
return
f
"{classname}{{{args}}}"
)
dct
[
"__str__"
]
=
__str__
dct
[
"__str__"
]
=
__str__
...
...
pytensor/link/c/basic.py
浏览文件 @
55b2f4fa
...
@@ -219,7 +219,6 @@ def struct_gen(args, struct_builders, blocks, sub):
...
@@ -219,7 +219,6 @@ def struct_gen(args, struct_builders, blocks, sub):
"""
"""
struct_decl
=
""
struct_decl
=
""
struct_init_head
=
""
struct_init_head
=
""
struct_init_tail
=
""
struct_cleanup
=
""
struct_cleanup
=
""
for
block
in
struct_builders
:
for
block
in
struct_builders
:
...
@@ -243,7 +242,6 @@ def struct_gen(args, struct_builders, blocks, sub):
...
@@ -243,7 +242,6 @@ def struct_gen(args, struct_builders, blocks, sub):
# decrements the storage's refcount in the destructor
# decrements the storage's refcount in the destructor
storage_decref
=
"
\n
"
.
join
(
f
"Py_XDECREF(this->{arg});"
for
arg
in
args
)
storage_decref
=
"
\n
"
.
join
(
f
"Py_XDECREF(this->{arg});"
for
arg
in
args
)
args_names
=
", "
.
join
(
args
)
args_decl
=
", "
.
join
(
f
"PyObject* {arg}"
for
arg
in
args
)
args_decl
=
", "
.
join
(
f
"PyObject* {arg}"
for
arg
in
args
)
# The following code stores the exception data in __ERROR, which
# The following code stores the exception data in __ERROR, which
...
@@ -251,7 +249,8 @@ def struct_gen(args, struct_builders, blocks, sub):
...
@@ -251,7 +249,8 @@ def struct_gen(args, struct_builders, blocks, sub):
# that holds the type, the value and the traceback. After storing
# that holds the type, the value and the traceback. After storing
# the error, we return the failure code so we know which code
# the error, we return the failure code so we know which code
# block failed.
# block failed.
do_return
=
"""
failure_var
=
sub
[
"failure_var"
]
do_return
=
f
"""
if ({failure_var}) {{
if ({failure_var}) {{
// When there is a failure, this code puts the exception
// When there is a failure, this code puts the exception
// in __ERROR.
// in __ERROR.
...
@@ -274,15 +273,13 @@ def struct_gen(args, struct_builders, blocks, sub):
...
@@ -274,15 +273,13 @@ def struct_gen(args, struct_builders, blocks, sub):
}}
}}
// The failure code is returned to index what code block failed.
// The failure code is returned to index what code block failed.
return {failure_var};
return {failure_var};
"""
.
format
(
**
sub
)
"""
sub
=
dict
(
sub
)
sub
.
update
(
locals
())
# TODO: add some error checking to make sure storage_<x> are
# TODO: add some error checking to make sure storage_<x> are
# 1-element lists and __ERROR is a 3-elements list.
# 1-element lists and __ERROR is a 3-elements list.
struct_code
=
"""
name
=
sub
[
"name"
]
struct_code
=
f
"""
namespace {{
namespace {{
struct {name} {{
struct {name} {{
PyObject* __ERROR;
PyObject* __ERROR;
...
@@ -326,7 +323,7 @@ def struct_gen(args, struct_builders, blocks, sub):
...
@@ -326,7 +323,7 @@ def struct_gen(args, struct_builders, blocks, sub):
}}
}}
}};
}};
}}
}}
"""
.
format
(
**
sub
)
"""
return
struct_code
return
struct_code
...
@@ -370,13 +367,10 @@ def get_c_init(fgraph, r, name, sub):
...
@@ -370,13 +367,10 @@ def get_c_init(fgraph, r, name, sub):
Wrapper around c_init that initializes py_name to Py_None.
Wrapper around c_init that initializes py_name to Py_None.
"""
"""
pre
=
(
pre
=
f
"""
""
"""
py_{name} = Py_None;
py_{name} = Py_None;
{{Py_XINCREF(py_{name});}}
{{Py_XINCREF(py_{name});}}
"""
.
format
(
**
locals
())
"""
)
return
pre
+
r
.
type
.
c_init
(
name
,
sub
)
return
pre
+
r
.
type
.
c_init
(
name
,
sub
)
...
@@ -410,10 +404,10 @@ def get_c_extract(fgraph, r, name, sub):
...
@@ -410,10 +404,10 @@ def get_c_extract(fgraph, r, name, sub):
else
:
else
:
c_extract
=
r
.
type
.
c_extract
(
name
,
sub
,
False
)
c_extract
=
r
.
type
.
c_extract
(
name
,
sub
,
False
)
pre
=
"""
pre
=
f
"""
py_{name} = PyList_GET_ITEM(storage_{name}, 0);
py_{name} = PyList_GET_ITEM(storage_{name}, 0);
{{Py_XINCREF(py_{name});}}
{{Py_XINCREF(py_{name});}}
"""
.
format
(
**
locals
())
"""
return
pre
+
c_extract
return
pre
+
c_extract
...
@@ -439,10 +433,10 @@ def get_c_extract_out(fgraph, r, name, sub):
...
@@ -439,10 +433,10 @@ def get_c_extract_out(fgraph, r, name, sub):
else
:
else
:
c_extract
=
r
.
type
.
c_extract_out
(
name
,
sub
,
check_input
,
check_broadcast
=
False
)
c_extract
=
r
.
type
.
c_extract_out
(
name
,
sub
,
check_input
,
check_broadcast
=
False
)
pre
=
"""
pre
=
f
"""
py_{name} = PyList_GET_ITEM(storage_{name}, 0);
py_{name} = PyList_GET_ITEM(storage_{name}, 0);
{{Py_XINCREF(py_{name});}}
{{Py_XINCREF(py_{name});}}
"""
.
format
(
**
locals
())
"""
return
pre
+
c_extract
return
pre
+
c_extract
...
@@ -451,9 +445,9 @@ def get_c_cleanup(fgraph, r, name, sub):
...
@@ -451,9 +445,9 @@ def get_c_cleanup(fgraph, r, name, sub):
Wrapper around c_cleanup that decrefs py_name.
Wrapper around c_cleanup that decrefs py_name.
"""
"""
post
=
"""
post
=
f
"""
{{Py_XDECREF(py_{name});}}
{{Py_XDECREF(py_{name});}}
"""
.
format
(
**
locals
())
"""
return
r
.
type
.
c_cleanup
(
name
,
sub
)
+
post
return
r
.
type
.
c_cleanup
(
name
,
sub
)
+
post
...
@@ -462,7 +456,9 @@ def get_c_sync(fgraph, r, name, sub):
...
@@ -462,7 +456,9 @@ def get_c_sync(fgraph, r, name, sub):
Wrapper around c_sync that syncs py_name with storage.
Wrapper around c_sync that syncs py_name with storage.
"""
"""
return
"""
failure_var
=
sub
[
"failure_var"
]
sync
=
r
.
type
.
c_sync
(
name
,
sub
)
return
f
"""
if (!{failure_var}) {{
if (!{failure_var}) {{
{sync}
{sync}
PyObject* old = PyList_GET_ITEM(storage_{name}, 0);
PyObject* old = PyList_GET_ITEM(storage_{name}, 0);
...
@@ -470,7 +466,7 @@ def get_c_sync(fgraph, r, name, sub):
...
@@ -470,7 +466,7 @@ def get_c_sync(fgraph, r, name, sub):
PyList_SET_ITEM(storage_{name}, 0, py_{name});
PyList_SET_ITEM(storage_{name}, 0, py_{name});
{{Py_XDECREF(old);}}
{{Py_XDECREF(old);}}
}}
}}
"""
.
format
(
**
dict
(
sync
=
r
.
type
.
c_sync
(
name
,
sub
),
name
=
name
,
**
sub
))
"""
def
apply_policy
(
fgraph
,
policy
,
r
,
name
,
sub
):
def
apply_policy
(
fgraph
,
policy
,
r
,
name
,
sub
):
...
...
pytensor/link/c/cmodule.py
浏览文件 @
55b2f4fa
...
@@ -1966,14 +1966,14 @@ class Compiler:
...
@@ -1966,14 +1966,14 @@ class Compiler:
return
False
return
False
code
=
(
code
=
(
"""
f
"""
{preamble}
{preamble}
int main(int argc, char** argv)
int main(int argc, char** argv)
{{
{{
{body}
{body}
return 0;
return 0;
}}
}}
"""
.
format
(
**
locals
())
"""
)
.
encode
()
)
.
encode
()
return
cls
.
_try_compile_tmp
(
return
cls
.
_try_compile_tmp
(
code
,
code
,
...
...
pytensor/link/c/interface.py
浏览文件 @
55b2f4fa
...
@@ -558,7 +558,9 @@ class CLinkerType(CLinkerObject):
...
@@ -558,7 +558,9 @@ class CLinkerType(CLinkerObject):
uninitialized.
uninitialized.
"""
"""
return
"""
c_init_code
=
self
.
c_init
(
name
,
sub
)
c_extract_code
=
self
.
c_extract
(
name
,
sub
,
check_input
)
return
f
"""
if (py_{name} == Py_None)
if (py_{name} == Py_None)
{{
{{
{c_init_code}
{c_init_code}
...
@@ -567,13 +569,7 @@ class CLinkerType(CLinkerObject):
...
@@ -567,13 +569,7 @@ class CLinkerType(CLinkerObject):
{{
{{
{c_extract_code}
{c_extract_code}
}}
}}
"""
.
format
(
"""
**
dict
(
name
=
name
,
c_init_code
=
self
.
c_init
(
name
,
sub
),
c_extract_code
=
self
.
c_extract
(
name
,
sub
,
check_input
),
)
)
def
c_cleanup
(
self
,
name
:
str
,
sub
:
dict
[
str
,
str
])
->
str
:
def
c_cleanup
(
self
,
name
:
str
,
sub
:
dict
[
str
,
str
])
->
str
:
"""Return C code to clean up after :meth:`CLinkerType.c_extract`.
"""Return C code to clean up after :meth:`CLinkerType.c_extract`.
...
...
pytensor/link/c/op.py
浏览文件 @
55b2f4fa
...
@@ -587,24 +587,15 @@ class ExternalCOp(COp):
...
@@ -587,24 +587,15 @@ class ExternalCOp(COp):
params
=
f
", {sub['params']}"
params
=
f
", {sub['params']}"
# Generate the C code
# Generate the C code
return
"""
return
f
"""
{define_macros}
{define_macros}
{{
{{
if ({
func_name}({func_args
}{params}) != 0) {{
if ({
self.func_name}({self.format_c_function_args(inp, out)
}{params}) != 0) {{
{
fail
}
{
sub['fail']
}
}}
}}
}}
}}
{undef_macros}
{undef_macros}
"""
.
format
(
"""
**
dict
(
func_name
=
self
.
func_name
,
fail
=
sub
[
"fail"
],
params
=
params
,
func_args
=
self
.
format_c_function_args
(
inp
,
out
),
define_macros
=
define_macros
,
undef_macros
=
undef_macros
,
)
)
else
:
else
:
if
"code"
in
self
.
code_sections
:
if
"code"
in
self
.
code_sections
:
op_code
=
self
.
code_sections
[
"code"
]
op_code
=
self
.
code_sections
[
"code"
]
...
...
pytensor/link/c/params_type.py
浏览文件 @
55b2f4fa
...
@@ -262,9 +262,10 @@ class Params(dict):
...
@@ -262,9 +262,10 @@ class Params(dict):
self
.
__dict__
.
update
(
__params_type__
=
params_type
,
__signatures__
=
None
)
self
.
__dict__
.
update
(
__params_type__
=
params_type
,
__signatures__
=
None
)
def
__repr__
(
self
):
def
__repr__
(
self
):
return
"Params({})"
.
format
(
args
=
", "
.
join
(
", "
.
join
((
f
"{k}:{type(self[k]).__name__}:{self[k]}"
)
for
k
in
sorted
(
self
)
)
(
f
"{k}:{type(self[k]).__name__}:{self[k]}"
)
for
k
in
sorted
(
self
)
)
)
return
f
"Params({args})"
def
__getattr__
(
self
,
key
):
def
__getattr__
(
self
,
key
):
if
key
not
in
self
:
if
key
not
in
self
:
...
@@ -422,9 +423,10 @@ class ParamsType(CType):
...
@@ -422,9 +423,10 @@ class ParamsType(CType):
return
super
()
.
__getattr__
(
self
,
key
)
return
super
()
.
__getattr__
(
self
,
key
)
def
__repr__
(
self
):
def
__repr__
(
self
):
return
"ParamsType<{}>"
.
format
(
args
=
", "
.
join
(
", "
.
join
((
f
"{self.fields[i]}:{self.types[i]}"
)
for
i
in
range
(
self
.
length
)
)
f
"{self.fields[i]}:{self.types[i]}"
for
i
in
range
(
self
.
length
)
)
)
return
f
"ParamsType<{args}>"
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
return
(
return
(
...
@@ -730,11 +732,15 @@ class ParamsType(CType):
...
@@ -730,11 +732,15 @@ class ParamsType(CType):
struct_init
=
"
\n
"
.
join
(
c_init_list
)
struct_init
=
"
\n
"
.
join
(
c_init_list
)
struct_cleanup
=
"
\n
"
.
join
(
c_cleanup_list
)
struct_cleanup
=
"
\n
"
.
join
(
c_cleanup_list
)
struct_extract
=
"
\n\n
"
.
join
(
c_extract_list
)
struct_extract
=
"
\n\n
"
.
join
(
c_extract_list
)
struct_extract_method
=
"""
args
=
"
\n
"
.
join
(
f
"case {i}: extract_{self.fields[i]}(object); break;"
for
i
in
range
(
self
.
length
)
)
struct_extract_method
=
f
"""
void extract(PyObject* object, int field_pos) {{
void extract(PyObject* object, int field_pos) {{
switch(field_pos) {{
switch(field_pos) {{
// Extraction cases.
// Extraction cases.
{}
{
args
}
// Default case.
// Default case.
default:
default:
PyErr_Format(PyExc_TypeError, "ParamsType: no extraction defined for a field
%
d.", field_pos);
PyErr_Format(PyExc_TypeError, "ParamsType: no extraction defined for a field
%
d.", field_pos);
...
@@ -742,13 +748,8 @@ class ParamsType(CType):
...
@@ -742,13 +748,8 @@ class ParamsType(CType):
break;
break;
}}
}}
}}
}}
"""
.
format
(
"""
"
\n
"
.
join
(
final_struct_code
=
f
"""
(
"case
%
d: extract_
%
s(object); break;"
%
(
i
,
self
.
fields
[
i
]))
for
i
in
range
(
self
.
length
)
)
)
final_struct_code
=
"""
/** ParamsType {struct_name} **/
/** ParamsType {struct_name} **/
#ifndef {struct_name_defined}
#ifndef {struct_name_defined}
#define {struct_name_defined}
#define {struct_name_defined}
...
@@ -790,17 +791,7 @@ class ParamsType(CType):
...
@@ -790,17 +791,7 @@ class ParamsType(CType):
}};
}};
#endif
#endif
/** End ParamsType {struct_name} **/
/** End ParamsType {struct_name} **/
"""
.
format
(
"""
**
dict
(
struct_name_defined
=
struct_name_defined
,
struct_name
=
struct_name
,
struct_declare
=
struct_declare
,
struct_init
=
struct_init
,
struct_cleanup
=
struct_cleanup
,
struct_extract
=
struct_extract
,
struct_extract_method
=
struct_extract_method
,
)
)
return
[
*
sorted
(
c_support_code_set
),
final_struct_code
]
return
[
*
sorted
(
c_support_code_set
),
final_struct_code
]
...
@@ -813,9 +804,9 @@ class ParamsType(CType):
...
@@ -813,9 +804,9 @@ class ParamsType(CType):
# pointers.
# pointers.
def
c_declare
(
self
,
name
,
sub
,
check_input
=
True
):
def
c_declare
(
self
,
name
,
sub
,
check_input
=
True
):
return
"""
return
f
"""
{s
truct_
name}* {name};
{s
elf.
name}* {name};
"""
.
format
(
**
dict
(
struct_name
=
self
.
name
,
name
=
name
))
"""
def
c_init
(
self
,
name
,
sub
):
def
c_init
(
self
,
name
,
sub
):
# NB: It seems c_init() is not called for an op param.
# NB: It seems c_init() is not called for an op param.
...
@@ -831,40 +822,33 @@ class ParamsType(CType):
...
@@ -831,40 +822,33 @@ class ParamsType(CType):
"""
"""
def
c_extract
(
self
,
name
,
sub
,
check_input
=
True
,
**
kwargs
):
def
c_extract
(
self
,
name
,
sub
,
check_input
=
True
,
**
kwargs
):
return
"""
fields_list
=
", "
.
join
(
f
'"{x}"'
for
x
in
self
.
fields
)
return
f
"""
/* Seems c_init() is not called for a op param. So I call `new` here. */
/* Seems c_init() is not called for a op param. So I call `new` here. */
{name} = new {s
truct_
name};
{name} = new {s
elf.
name};
{{ // This need a separate namespace for Clinker
{{ // This need a separate namespace for Clinker
const char* fields[] = {{{fields_list}}};
const char* fields[] = {{{fields_list}}};
if (py_{name} == Py_None) {{
if (py_{name} == Py_None) {{
PyErr_SetString(PyExc_ValueError, "ParamsType: expected an object, not None.");
PyErr_SetString(PyExc_ValueError, "ParamsType: expected an object, not None.");
{
fail
}
{
sub['fail']
}
}}
}}
for (int i = 0; i < {length}; ++i) {{
for (int i = 0; i < {
self.
length}; ++i) {{
PyObject* o = PyDict_GetItemString(py_{name}, fields[i]);
PyObject* o = PyDict_GetItemString(py_{name}, fields[i]);
if (o == NULL) {{
if (o == NULL) {{
PyErr_Format(PyExc_TypeError, "ParamsType: missing expected attribute
\\
"
%
s
\\
" in object.", fields[i]);
PyErr_Format(PyExc_TypeError, "ParamsType: missing expected attribute
\\
"
%
s
\\
" in object.", fields[i]);
{
fail
}
{
sub['fail']
}
}}
}}
{name}->extract(o, i);
{name}->extract(o, i);
if ({name}->errorOccurred()) {{
if ({name}->errorOccurred()) {{
/* The extract code from attribute type should have already raised a Python exception,
/* The extract code from attribute type should have already raised a Python exception,
* so we just print the attribute name in stderr. */
* so we just print the attribute name in stderr. */
fprintf(stderr, "
\\
nParamsType: error when extracting value for attribute
\\
"
%
s
\\
".
\\
n", fields[i]);
fprintf(stderr, "
\\
nParamsType: error when extracting value for attribute
\\
"
%
s
\\
".
\\
n", fields[i]);
{
fail
}
{
sub['fail']
}
}}
}}
}}
}}
}}
}}
"""
.
format
(
"""
**
dict
(
name
=
name
,
struct_name
=
self
.
name
,
length
=
self
.
length
,
fail
=
sub
[
"fail"
],
fields_list
=
'"{}"'
.
format
(
'", "'
.
join
(
self
.
fields
)),
)
)
def
c_sync
(
self
,
name
,
sub
):
def
c_sync
(
self
,
name
,
sub
):
# FIXME: Looks like we need to decrement a reference count our two.
# FIXME: Looks like we need to decrement a reference count our two.
...
...
pytensor/link/c/type.py
浏览文件 @
55b2f4fa
...
@@ -98,12 +98,12 @@ class Generic(CType, Singleton):
...
@@ -98,12 +98,12 @@ class Generic(CType, Singleton):
"""
"""
def
c_sync
(
self
,
name
,
sub
):
def
c_sync
(
self
,
name
,
sub
):
return
"""
return
f
"""
assert(py_{name}->ob_refcnt > 1);
assert(py_{name}->ob_refcnt > 1);
Py_DECREF(py_{name});
Py_DECREF(py_{name});
py_{name} = {name} ? {name} : Py_None;
py_{name} = {name} ? {name} : Py_None;
Py_INCREF(py_{name});
Py_INCREF(py_{name});
"""
.
format
(
**
locals
())
"""
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
return
(
1
,)
return
(
1
,)
...
@@ -190,18 +190,19 @@ class CDataType(CType[D]):
...
@@ -190,18 +190,19 @@ class CDataType(CType[D]):
return
data
return
data
def
c_declare
(
self
,
name
,
sub
,
check_input
=
True
):
def
c_declare
(
self
,
name
,
sub
,
check_input
=
True
):
return
"""
return
f
"""
{ctype} {name};
{
self.
ctype} {name};
"""
.
format
(
**
dict
(
ctype
=
self
.
ctype
,
name
=
name
))
"""
def
c_init
(
self
,
name
,
sub
):
def
c_init
(
self
,
name
,
sub
):
return
f
"{name} = NULL;"
return
f
"{name} = NULL;"
def
c_extract
(
self
,
name
,
sub
,
check_input
=
True
,
**
kwargs
):
def
c_extract
(
self
,
name
,
sub
,
check_input
=
True
,
**
kwargs
):
return
"""
fail
=
sub
[
"fail"
]
{name} = ({ctype})PyCapsule_GetPointer(py_{name}, NULL);
return
f
"""
{name} = ({self.ctype})PyCapsule_GetPointer(py_{name}, NULL);
if ({name} == NULL) {fail}
if ({name} == NULL) {fail}
"""
.
format
(
**
dict
(
name
=
name
,
ctype
=
self
.
ctype
,
fail
=
sub
[
"fail"
]))
"""
def
c_sync
(
self
,
name
,
sub
):
def
c_sync
(
self
,
name
,
sub
):
freefunc
=
self
.
freefunc
freefunc
=
self
.
freefunc
...
@@ -478,11 +479,8 @@ class EnumType(CType, dict):
...
@@ -478,11 +479,8 @@ class EnumType(CType, dict):
names_to_aliases
=
{
constant_name
:
""
for
constant_name
in
self
}
names_to_aliases
=
{
constant_name
:
""
for
constant_name
in
self
}
for
alias
in
self
.
aliases
:
for
alias
in
self
.
aliases
:
names_to_aliases
[
self
.
aliases
[
alias
]]
=
f
"({alias})"
names_to_aliases
[
self
.
aliases
[
alias
]]
=
f
"({alias})"
return
"{}<{}>({})"
.
format
(
args
=
", "
.
join
(
f
"{k}{names_to_aliases[k]}:{self[k]}"
for
k
in
sorted
(
self
))
type
(
self
)
.
__name__
,
return
f
"{type(self).__name__}<{self.ctype}>({args})"
self
.
ctype
,
", "
.
join
(
f
"{k}{names_to_aliases[k]}:{self[k]}"
for
k
in
sorted
(
self
)),
)
def
__getattr__
(
self
,
key
):
def
__getattr__
(
self
,
key
):
if
key
in
self
:
if
key
in
self
:
...
@@ -575,33 +573,27 @@ class EnumType(CType, dict):
...
@@ -575,33 +573,27 @@ class EnumType(CType, dict):
This C function may be useful to retrieve some runtime information.
This C function may be useful to retrieve some runtime information.
It is available in C code when pytensor flag ``config.cmodule__debug`` is set to ``True``.
It is available in C code when pytensor flag ``config.cmodule__debug`` is set to ``True``.
"""
"""
return
"""
cases
=
""
.
join
(
f
"""
case {name}: sprintf(out, "{name}"); break;
"""
for
name
in
self
)
return
f
"""
#ifdef DEBUG
#ifdef DEBUG
int pytensor_enum_to_string_{
cname}({
ctype} in, char* out) {{
int pytensor_enum_to_string_{
self.cname}({self.
ctype} in, char* out) {{
int ret = 0;
int ret = 0;
switch(in) {{
switch(in) {{
{cases}
{cases}
default:
default:
PyErr_SetString(PyExc_ValueError, "{
classname
}: unknown enum value.");
PyErr_SetString(PyExc_ValueError, "{
type(self).__name__
}: unknown enum value.");
ret = -1;
ret = -1;
break;
break;
}}
}}
return ret;
return ret;
}}
}}
#endif
#endif
"""
.
format
(
"""
**
dict
(
cname
=
self
.
cname
,
ctype
=
self
.
ctype
,
classname
=
type
(
self
)
.
__name__
,
cases
=
""
.
join
(
"""
case {name}: sprintf(out, "{name}"); break;
"""
.
format
(
**
dict
(
name
=
name
))
for
name
in
self
),
)
)
def
c_support_code
(
self
,
**
kwargs
):
def
c_support_code
(
self
,
**
kwargs
):
return
(
return
(
...
@@ -625,16 +617,17 @@ class EnumType(CType, dict):
...
@@ -625,16 +617,17 @@ class EnumType(CType, dict):
return
""
return
""
def
c_extract
(
self
,
name
,
sub
,
check_input
=
True
,
**
kwargs
):
def
c_extract
(
self
,
name
,
sub
,
check_input
=
True
,
**
kwargs
):
return
"""
fail
=
sub
[
"fail"
]
return
f
"""
if (PyInt_Check(py_{name})) {{
if (PyInt_Check(py_{name})) {{
{name} = ({ctype})PyInt_AsLong(py_{name});
{name} = ({
self.
ctype})PyInt_AsLong(py_{name});
}} else {{
}} else {{
{name} = ({ctype})PyFloat_AsDouble(py_{name});
{name} = ({
self.
ctype})PyFloat_AsDouble(py_{name});
}}
}}
if (PyErr_Occurred()) {{
if (PyErr_Occurred()) {{
{fail}
{fail}
}}
}}
"""
.
format
(
**
dict
(
ctype
=
self
.
ctype
,
name
=
name
,
fail
=
sub
[
"fail"
]))
"""
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
return
(
2
,
self
.
ctype
,
self
.
cname
,
tuple
(
self
.
items
()))
return
(
2
,
self
.
ctype
,
self
.
cname
,
tuple
(
self
.
items
()))
...
@@ -754,7 +747,14 @@ class CEnumType(EnumList):
...
@@ -754,7 +747,14 @@ class CEnumType(EnumList):
swapped_dict
=
{
v
:
k
for
(
k
,
v
)
in
self
.
items
()}
swapped_dict
=
{
v
:
k
for
(
k
,
v
)
in
self
.
items
()}
# swapped_dict's keys are integers.
# swapped_dict's keys are integers.
return
"""
fail
=
sub
[
"fail"
]
cases
=
""
.
join
(
f
"""
case {i}: {name} = {swapped_dict[i]}; break;
"""
for
i
in
sorted
(
swapped_dict
)
)
return
f
"""
switch(PyInt_AsLong(py_{name})) {{
switch(PyInt_AsLong(py_{name})) {{
{cases}
{cases}
default:
default:
...
@@ -762,19 +762,7 @@ class CEnumType(EnumList):
...
@@ -762,19 +762,7 @@ class CEnumType(EnumList):
{{{fail}}}
{{{fail}}}
break;
break;
}}
}}
"""
.
format
(
"""
**
dict
(
name
=
name
,
cases
=
""
.
join
(
"""
case
%(i)
d:
%(name)
s =
%(constant_cname)
s; break;
"""
%
dict
(
i
=
i
,
name
=
name
,
constant_cname
=
swapped_dict
[
i
])
for
i
in
sorted
(
swapped_dict
)
),
fail
=
sub
[
"fail"
],
)
)
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
return
(
1
,
super
()
.
c_code_cache_version
())
return
(
1
,
super
()
.
c_code_cache_version
())
pytensor/link/numba/dispatch/slinalg.py
浏览文件 @
55b2f4fa
...
@@ -136,20 +136,19 @@ def _check_scipy_linalg_matrix(a, func_name):
...
@@ -136,20 +136,19 @@ def _check_scipy_linalg_matrix(a, func_name):
Adapted from https://github.com/numba/numba/blob/bd7ebcfd4b850208b627a3f75d4706000be36275/numba/np/linalg.py#L831
Adapted from https://github.com/numba/numba/blob/bd7ebcfd4b850208b627a3f75d4706000be36275/numba/np/linalg.py#L831
"""
"""
prefix
=
"scipy.linalg"
prefix
=
"scipy.linalg"
interp
=
(
prefix
,
func_name
)
# Unpack optional type
# Unpack optional type
if
isinstance
(
a
,
types
.
Optional
):
if
isinstance
(
a
,
types
.
Optional
):
a
=
a
.
type
a
=
a
.
type
if
not
isinstance
(
a
,
types
.
Array
):
if
not
isinstance
(
a
,
types
.
Array
):
msg
=
"{}.{}() only supported for array types"
.
format
(
*
interp
)
msg
=
f
"{prefix}.{func_name}() only supported for array types"
raise
numba
.
TypingError
(
msg
,
highlighting
=
False
)
raise
numba
.
TypingError
(
msg
,
highlighting
=
False
)
if
a
.
ndim
not
in
[
1
,
2
]:
if
a
.
ndim
not
in
[
1
,
2
]:
msg
=
"{}.{}() only supported on 1d or 2d arrays, found {}."
.
format
(
msg
=
(
*
interp
,
a
.
ndim
f
"{prefix}.{func_name}() only supported on 1d or 2d arrays, found {a.ndim}."
)
)
raise
numba
.
TypingError
(
msg
,
highlighting
=
False
)
raise
numba
.
TypingError
(
msg
,
highlighting
=
False
)
if
not
isinstance
(
a
.
dtype
,
types
.
Float
|
types
.
Complex
):
if
not
isinstance
(
a
.
dtype
,
types
.
Float
|
types
.
Complex
):
msg
=
"{
}.{}() only supported on "
"float and complex arrays."
.
format
(
*
interp
)
msg
=
"{
prefix}.{func_name}() only supported on float and complex arrays."
raise
numba
.
TypingError
(
msg
,
highlighting
=
False
)
raise
numba
.
TypingError
(
msg
,
highlighting
=
False
)
...
...
pytensor/link/utils.py
浏览文件 @
55b2f4fa
...
@@ -355,14 +355,10 @@ def raise_with_op(
...
@@ -355,14 +355,10 @@ def raise_with_op(
+
f
"
\n
Inputs values: {scalar_values}"
+
f
"
\n
Inputs values: {scalar_values}"
)
)
if
verbosity
==
"high"
:
if
verbosity
==
"high"
:
detailed_err_msg
+=
"
\n
Inputs type_num: {}"
.
format
(
inpts
=
[
str
(
getattr
(
getattr
(
i
[
0
],
"dtype"
,
""
),
"num"
,
""
)
for
i
in
thunk
.
inputs
[
]
getattr
(
getattr
(
i
[
0
],
"dtype"
,
""
),
"num"
,
""
)
detailed_err_msg
+=
f
"
\n
Inputs type_num: {inpts}"
for
i
in
thunk
.
inputs
]
)
)
detailed_err_msg
+=
f
"
\n
Outputs clients: {clients}
\n
"
detailed_err_msg
+=
f
"
\n
Outputs clients: {clients}
\n
"
else
:
else
:
...
...
pytensor/printing.py
浏览文件 @
55b2f4fa
...
@@ -1048,10 +1048,8 @@ class DefaultPrinter(Printer):
...
@@ -1048,10 +1048,8 @@ class DefaultPrinter(Printer):
if
node
is
None
:
if
node
is
None
:
return
leaf_printer
.
process
(
output
,
pstate
)
return
leaf_printer
.
process
(
output
,
pstate
)
with
set_precedence
(
pstate
):
with
set_precedence
(
pstate
):
r
=
"{}({})"
.
format
(
args
=
", "
.
join
(
pprinter
.
process
(
input
,
pstate
)
for
input
in
node
.
inputs
)
str
(
node
.
op
),
r
=
f
"{node.op}({args})"
", "
.
join
(
pprinter
.
process
(
input
,
pstate
)
for
input
in
node
.
inputs
),
)
pstate
.
memo
[
output
]
=
r
pstate
.
memo
[
output
]
=
r
return
r
return
r
...
...
pytensor/scalar/basic.py
浏览文件 @
55b2f4fa
...
@@ -464,33 +464,32 @@ class ScalarType(CType, HasDataType, HasShape):
...
@@ -464,33 +464,32 @@ class ScalarType(CType, HasDataType, HasShape):
raise
NotImplementedError
(
"float16"
)
raise
NotImplementedError
(
"float16"
)
specs
=
self
.
dtype_specs
()
specs
=
self
.
dtype_specs
()
if
check_input
:
if
check_input
:
pre
=
"""
fail
=
sub
[
"fail"
]
dtype
=
specs
[
1
]
pyarr_type
=
f
"Py{specs[2]}ArrType_Type"
pre
=
f
"""
if (!PyObject_TypeCheck(py_{name}, &{pyarr_type}))
if (!PyObject_TypeCheck(py_{name}, &{pyarr_type}))
{{
{{
PyErr_Format(PyExc_ValueError,
PyErr_Format(PyExc_ValueError,
"Scalar check failed ({dtype})");
"Scalar check failed ({dtype})");
{fail}
{fail}
}}
}}
"""
.
format
(
"""
**
dict
(
sub
,
name
=
name
,
dtype
=
specs
[
1
],
pyarr_type
=
f
"Py{specs[2]}ArrType_Type"
,
)
)
else
:
else
:
pre
=
""
pre
=
""
return
(
return
(
pre
pre
+
"""
+
f
"""
PyArray_ScalarAsCtype(py_{name}, &{name});
PyArray_ScalarAsCtype(py_{name}, &{name});
"""
.
format
(
**
dict
(
sub
,
name
=
name
))
"""
)
)
def
c_sync
(
self
,
name
,
sub
):
def
c_sync
(
self
,
name
,
sub
):
specs
=
self
.
dtype_specs
()
specs
=
self
.
dtype_specs
()
return
"""
fail
=
sub
[
"fail"
]
dtype
=
specs
[
1
]
cls
=
specs
[
2
]
return
f
"""
Py_XDECREF(py_{name});
Py_XDECREF(py_{name});
py_{name} = PyArrayScalar_New({cls});
py_{name} = PyArrayScalar_New({cls});
if (!py_{name})
if (!py_{name})
...
@@ -502,7 +501,7 @@ class ScalarType(CType, HasDataType, HasShape):
...
@@ -502,7 +501,7 @@ class ScalarType(CType, HasDataType, HasShape):
{fail}
{fail}
}}
}}
PyArrayScalar_ASSIGN(py_{name}, {cls}, {name});
PyArrayScalar_ASSIGN(py_{name}, {cls}, {name});
"""
.
format
(
**
dict
(
sub
,
name
=
name
,
dtype
=
specs
[
1
],
cls
=
specs
[
2
]))
"""
def
c_cleanup
(
self
,
name
,
sub
):
def
c_cleanup
(
self
,
name
,
sub
):
return
""
return
""
...
@@ -1179,10 +1178,9 @@ class ScalarOp(COp):
...
@@ -1179,10 +1178,9 @@ class ScalarOp(COp):
not
in
(
"name"
,
"_op_use_c_code"
,
"bool"
,
"output_types_preference"
)
not
in
(
"name"
,
"_op_use_c_code"
,
"bool"
,
"output_types_preference"
)
]
]
if
param
:
if
param
:
return
"{}{{{}}}"
.
format
(
classname
=
self
.
__class__
.
__name__
self
.
__class__
.
__name__
,
args
=
", "
.
join
(
f
"{k}={v}"
for
k
,
v
in
param
)
", "
.
join
(
f
"{k}={v}"
for
k
,
v
in
param
),
return
f
"{classname}{{{args}}}"
)
else
:
else
:
return
self
.
__class__
.
__name__
return
self
.
__class__
.
__name__
...
...
pytensor/scalar/math.py
浏览文件 @
55b2f4fa
...
@@ -615,8 +615,8 @@ class Chi2SF(BinaryScalarOp):
...
@@ -615,8 +615,8 @@ class Chi2SF(BinaryScalarOp):
(
z
,)
=
out
(
z
,)
=
out
if
node
.
inputs
[
0
]
.
type
in
float_types
:
if
node
.
inputs
[
0
]
.
type
in
float_types
:
dtype
=
"npy_"
+
node
.
outputs
[
0
]
.
dtype
dtype
=
"npy_"
+
node
.
outputs
[
0
]
.
dtype
return
"""{z} =
return
f
"""{z} =
({dtype}) 1 - GammaP({k}/2., {x}/2.);"""
.
format
(
**
locals
())
({dtype}) 1 - GammaP({k}/2., {x}/2.);"""
raise
NotImplementedError
(
"only floatingpoint is implemented"
)
raise
NotImplementedError
(
"only floatingpoint is implemented"
)
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
...
@@ -666,8 +666,8 @@ class GammaInc(BinaryScalarOp):
...
@@ -666,8 +666,8 @@ class GammaInc(BinaryScalarOp):
(
z
,)
=
out
(
z
,)
=
out
if
node
.
inputs
[
0
]
.
type
in
float_types
:
if
node
.
inputs
[
0
]
.
type
in
float_types
:
dtype
=
"npy_"
+
node
.
outputs
[
0
]
.
dtype
dtype
=
"npy_"
+
node
.
outputs
[
0
]
.
dtype
return
"""{z} =
return
f
"""{z} =
({dtype}) GammaP({k}, {x});"""
.
format
(
**
locals
())
({dtype}) GammaP({k}, {x});"""
raise
NotImplementedError
(
"only floatingpoint is implemented"
)
raise
NotImplementedError
(
"only floatingpoint is implemented"
)
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
...
@@ -717,8 +717,8 @@ class GammaIncC(BinaryScalarOp):
...
@@ -717,8 +717,8 @@ class GammaIncC(BinaryScalarOp):
(
z
,)
=
out
(
z
,)
=
out
if
node
.
inputs
[
0
]
.
type
in
float_types
:
if
node
.
inputs
[
0
]
.
type
in
float_types
:
dtype
=
"npy_"
+
node
.
outputs
[
0
]
.
dtype
dtype
=
"npy_"
+
node
.
outputs
[
0
]
.
dtype
return
"""{z} =
return
f
"""{z} =
({dtype}) GammaQ({k}, {x});"""
.
format
(
**
locals
())
({dtype}) GammaQ({k}, {x});"""
raise
NotImplementedError
(
"only floatingpoint is implemented"
)
raise
NotImplementedError
(
"only floatingpoint is implemented"
)
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
...
@@ -1028,8 +1028,8 @@ class GammaU(BinaryScalarOp):
...
@@ -1028,8 +1028,8 @@ class GammaU(BinaryScalarOp):
(
z
,)
=
out
(
z
,)
=
out
if
node
.
inputs
[
0
]
.
type
in
float_types
:
if
node
.
inputs
[
0
]
.
type
in
float_types
:
dtype
=
"npy_"
+
node
.
outputs
[
0
]
.
dtype
dtype
=
"npy_"
+
node
.
outputs
[
0
]
.
dtype
return
"""{z} =
return
f
"""{z} =
({dtype}) upperGamma({k}, {x});"""
.
format
(
**
locals
())
({dtype}) upperGamma({k}, {x});"""
raise
NotImplementedError
(
"only floatingpoint is implemented"
)
raise
NotImplementedError
(
"only floatingpoint is implemented"
)
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
...
@@ -1064,8 +1064,8 @@ class GammaL(BinaryScalarOp):
...
@@ -1064,8 +1064,8 @@ class GammaL(BinaryScalarOp):
(
z
,)
=
out
(
z
,)
=
out
if
node
.
inputs
[
0
]
.
type
in
float_types
:
if
node
.
inputs
[
0
]
.
type
in
float_types
:
dtype
=
"npy_"
+
node
.
outputs
[
0
]
.
dtype
dtype
=
"npy_"
+
node
.
outputs
[
0
]
.
dtype
return
"""{z} =
return
f
"""{z} =
({dtype}) lowerGamma({k}, {x});"""
.
format
(
**
locals
())
({dtype}) lowerGamma({k}, {x});"""
raise
NotImplementedError
(
"only floatingpoint is implemented"
)
raise
NotImplementedError
(
"only floatingpoint is implemented"
)
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
...
...
pytensor/scan/utils.py
浏览文件 @
55b2f4fa
...
@@ -1072,7 +1072,8 @@ class ScanArgs:
...
@@ -1072,7 +1072,8 @@ class ScanArgs:
for
p
in
self
.
field_names
for
p
in
self
.
field_names
if
p
.
startswith
(
"outer_out"
)
if
p
.
startswith
(
"outer_out"
)
]
]
res
=
"ScanArgs(
\n
{})"
.
format
(
",
\n
"
.
join
(
inner_arg_strs
))
args
=
",
\n
"
.
join
(
inner_arg_strs
)
res
=
f
"ScanArgs(
\n
{args})"
return
res
return
res
def
__repr__
(
self
):
def
__repr__
(
self
):
...
...
pytensor/sparse/basic.py
浏览文件 @
55b2f4fa
...
@@ -3617,7 +3617,8 @@ class StructuredDotGradCSC(COp):
...
@@ -3617,7 +3617,8 @@ class StructuredDotGradCSC(COp):
if
node
.
inputs
[
3
]
.
type
.
dtype
in
(
"complex64"
,
"complex128"
):
if
node
.
inputs
[
3
]
.
type
.
dtype
in
(
"complex64"
,
"complex128"
):
raise
NotImplementedError
(
"Complex types are not supported for g_ab"
)
raise
NotImplementedError
(
"Complex types are not supported for g_ab"
)
return
"""
fail
=
sub
[
"fail"
]
return
f
"""
if (PyArray_NDIM({_d}) != 2) {{PyErr_SetString(PyExc_NotImplementedError, "rank(d) != 2"); {fail};}}
if (PyArray_NDIM({_d}) != 2) {{PyErr_SetString(PyExc_NotImplementedError, "rank(d) != 2"); {fail};}}
if (PyArray_NDIM({_g}) != 2) {{PyErr_SetString(PyExc_NotImplementedError, "rank(g) != 2"); {fail};}}
if (PyArray_NDIM({_g}) != 2) {{PyErr_SetString(PyExc_NotImplementedError, "rank(g) != 2"); {fail};}}
if (PyArray_NDIM({_indices}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(indices) != 1"); {fail};}}
if (PyArray_NDIM({_indices}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(indices) != 1"); {fail};}}
...
@@ -3689,7 +3690,7 @@ class StructuredDotGradCSC(COp):
...
@@ -3689,7 +3690,7 @@ class StructuredDotGradCSC(COp):
}}
}}
}}
}}
"""
.
format
(
**
dict
(
locals
(),
**
sub
))
"""
def
infer_shape
(
self
,
fgraph
,
node
,
shapes
):
def
infer_shape
(
self
,
fgraph
,
node
,
shapes
):
return
[
shapes
[
0
]]
return
[
shapes
[
0
]]
...
@@ -3750,7 +3751,8 @@ class StructuredDotGradCSR(COp):
...
@@ -3750,7 +3751,8 @@ class StructuredDotGradCSR(COp):
if
node
.
inputs
[
3
]
.
type
.
dtype
in
(
"complex64"
,
"complex128"
):
if
node
.
inputs
[
3
]
.
type
.
dtype
in
(
"complex64"
,
"complex128"
):
raise
NotImplementedError
(
"Complex types are not supported for g_ab"
)
raise
NotImplementedError
(
"Complex types are not supported for g_ab"
)
return
"""
fail
=
sub
[
"fail"
]
return
f
"""
if (PyArray_NDIM({_d}) != 2) {{PyErr_SetString(PyExc_NotImplementedError, "rank(d) != 2"); {fail};}}
if (PyArray_NDIM({_d}) != 2) {{PyErr_SetString(PyExc_NotImplementedError, "rank(d) != 2"); {fail};}}
if (PyArray_NDIM({_g}) != 2) {{PyErr_SetString(PyExc_NotImplementedError, "rank(g) != 2"); {fail};}}
if (PyArray_NDIM({_g}) != 2) {{PyErr_SetString(PyExc_NotImplementedError, "rank(g) != 2"); {fail};}}
if (PyArray_NDIM({_indices}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(indices) != 1"); {fail};}}
if (PyArray_NDIM({_indices}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(indices) != 1"); {fail};}}
...
@@ -3823,7 +3825,7 @@ class StructuredDotGradCSR(COp):
...
@@ -3823,7 +3825,7 @@ class StructuredDotGradCSR(COp):
}}
}}
}}
}}
"""
.
format
(
**
dict
(
locals
(),
**
sub
))
"""
def
infer_shape
(
self
,
fgraph
,
node
,
shapes
):
def
infer_shape
(
self
,
fgraph
,
node
,
shapes
):
return
[
shapes
[
0
]]
return
[
shapes
[
0
]]
...
...
pytensor/sparse/rewriting.py
浏览文件 @
55b2f4fa
...
@@ -137,7 +137,7 @@ class AddSD_ccode(_NoPythonCOp):
...
@@ -137,7 +137,7 @@ class AddSD_ccode(_NoPythonCOp):
inplace
=
int
(
self
.
inplace
)
inplace
=
int
(
self
.
inplace
)
format
=
{
"csc"
:
0
,
"csr"
:
1
}[
self
.
format
]
format
=
{
"csc"
:
0
,
"csr"
:
1
}[
self
.
format
]
out_typenum
=
node
.
outputs
[
0
]
.
type
.
dtype_specs
()[
2
]
out_typenum
=
node
.
outputs
[
0
]
.
type
.
dtype_specs
()[
2
]
code
=
"""
code
=
f
"""
Py_XDECREF({z});
Py_XDECREF({z});
if (!{inplace}){{
if (!{inplace}){{
if(PyArray_TYPE({y}) != {out_typenum}){{
if(PyArray_TYPE({y}) != {out_typenum}){{
...
@@ -179,7 +179,7 @@ class AddSD_ccode(_NoPythonCOp):
...
@@ -179,7 +179,7 @@ class AddSD_ccode(_NoPythonCOp):
}}
}}
}}
}}
}}
}}
"""
.
format
(
**
dict
(
locals
(),
**
sub
))
"""
return
code
return
code
def
infer_shape
(
self
,
fgraph
,
node
,
shapes
):
def
infer_shape
(
self
,
fgraph
,
node
,
shapes
):
...
@@ -310,7 +310,8 @@ class StructuredDotCSC(COp):
...
@@ -310,7 +310,8 @@ class StructuredDotCSC(COp):
typenum_a_val
=
node
.
inputs
[
0
]
.
type
.
dtype_specs
()[
2
]
# retrieve dtype number
typenum_a_val
=
node
.
inputs
[
0
]
.
type
.
dtype_specs
()[
2
]
# retrieve dtype number
typenum_b
=
node
.
inputs
[
4
]
.
type
.
dtype_specs
()[
2
]
# retrieve dtype number
typenum_b
=
node
.
inputs
[
4
]
.
type
.
dtype_specs
()[
2
]
# retrieve dtype number
rval
=
"""
fail
=
sub
[
"fail"
]
rval
=
f
"""
if (PyArray_NDIM({a_val}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a_val) != 1"); {fail};}}
if (PyArray_NDIM({a_val}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a_val) != 1"); {fail};}}
if (PyArray_NDIM({a_ind}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a_ind) != 1"); {fail};}}
if (PyArray_NDIM({a_ind}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a_ind) != 1"); {fail};}}
...
@@ -430,7 +431,7 @@ class StructuredDotCSC(COp):
...
@@ -430,7 +431,7 @@ class StructuredDotCSC(COp):
}}
}}
}}
}}
}}
}}
"""
.
format
(
**
dict
(
locals
(),
**
sub
))
"""
return
rval
return
rval
...
@@ -517,7 +518,8 @@ class StructuredDotCSR(COp):
...
@@ -517,7 +518,8 @@ class StructuredDotCSR(COp):
if
node
.
inputs
[
3
]
.
type
.
dtype
in
(
"complex64"
,
"complex128"
):
if
node
.
inputs
[
3
]
.
type
.
dtype
in
(
"complex64"
,
"complex128"
):
raise
NotImplementedError
(
"Complex types are not supported for b"
)
raise
NotImplementedError
(
"Complex types are not supported for b"
)
return
"""
fail
=
sub
[
"fail"
]
return
f
"""
if (PyArray_NDIM({a_val}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a_val) != 1"); {fail};}}
if (PyArray_NDIM({a_val}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a_val) != 1"); {fail};}}
if (PyArray_NDIM({a_ind}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a_ind) != 1"); {fail};}}
if (PyArray_NDIM({a_ind}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a_ind) != 1"); {fail};}}
if (PyArray_NDIM({a_ptr}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a_ptr) != 1"); {fail};}}
if (PyArray_NDIM({a_ptr}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a_ptr) != 1"); {fail};}}
...
@@ -609,7 +611,7 @@ class StructuredDotCSR(COp):
...
@@ -609,7 +611,7 @@ class StructuredDotCSR(COp):
}}
}}
}}
}}
"""
.
format
(
**
dict
(
locals
(),
**
sub
))
"""
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
return
(
2
,)
return
(
2
,)
...
@@ -756,7 +758,8 @@ class UsmmCscDense(_NoPythonCOp):
...
@@ -756,7 +758,8 @@ class UsmmCscDense(_NoPythonCOp):
inplace
=
int
(
self
.
inplace
)
inplace
=
int
(
self
.
inplace
)
rval
=
"""
fail
=
sub
[
"fail"
]
rval
=
f
"""
if (PyArray_NDIM({x_val}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(x_val) != 1"); {fail};}}
if (PyArray_NDIM({x_val}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(x_val) != 1"); {fail};}}
if (PyArray_NDIM({x_ind}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(x_ind) != 1"); {fail};}}
if (PyArray_NDIM({x_ind}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(x_ind) != 1"); {fail};}}
...
@@ -888,7 +891,7 @@ class UsmmCscDense(_NoPythonCOp):
...
@@ -888,7 +891,7 @@ class UsmmCscDense(_NoPythonCOp):
}}
}}
}}
}}
}}
}}
"""
.
format
(
**
dict
(
locals
(),
**
sub
))
"""
return
rval
return
rval
...
@@ -985,7 +988,8 @@ class CSMGradC(_NoPythonCOp):
...
@@ -985,7 +988,8 @@ class CSMGradC(_NoPythonCOp):
if
node
.
inputs
[
3
]
.
type
.
dtype
in
(
"complex64"
,
"complex128"
):
if
node
.
inputs
[
3
]
.
type
.
dtype
in
(
"complex64"
,
"complex128"
):
raise
NotImplementedError
(
"Complex types are not supported for b_val"
)
raise
NotImplementedError
(
"Complex types are not supported for b_val"
)
return
"""
fail
=
sub
[
"fail"
]
return
f
"""
if (PyArray_NDIM({a_val}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a_val) != 1"); {fail};}}
if (PyArray_NDIM({a_val}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a_val) != 1"); {fail};}}
if (PyArray_NDIM({a_ind}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a_ind) != 1"); {fail};}}
if (PyArray_NDIM({a_ind}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a_ind) != 1"); {fail};}}
if (PyArray_NDIM({a_ptr}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a_ptr) != 1"); {fail};}}
if (PyArray_NDIM({a_ptr}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a_ptr) != 1"); {fail};}}
...
@@ -1079,7 +1083,7 @@ class CSMGradC(_NoPythonCOp):
...
@@ -1079,7 +1083,7 @@ class CSMGradC(_NoPythonCOp):
}}
}}
}}
}}
"""
.
format
(
**
dict
(
locals
(),
**
sub
))
"""
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
return
(
3
,)
return
(
3
,)
...
@@ -1165,7 +1169,8 @@ class MulSDCSC(_NoPythonCOp):
...
@@ -1165,7 +1169,8 @@ class MulSDCSC(_NoPythonCOp):
if
node
.
inputs
[
3
]
.
type
.
dtype
in
(
"complex64"
,
"complex128"
):
if
node
.
inputs
[
3
]
.
type
.
dtype
in
(
"complex64"
,
"complex128"
):
raise
NotImplementedError
(
"Complex types are not supported for b"
)
raise
NotImplementedError
(
"Complex types are not supported for b"
)
return
"""
fail
=
sub
[
"fail"
]
return
f
"""
if (PyArray_NDIM({_b}) != 2) {{
if (PyArray_NDIM({_b}) != 2) {{
PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 2");
PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 2");
{fail};}}
{fail};}}
...
@@ -1231,7 +1236,7 @@ class MulSDCSC(_NoPythonCOp):
...
@@ -1231,7 +1236,7 @@ class MulSDCSC(_NoPythonCOp):
}}
}}
}}
}}
"""
.
format
(
**
dict
(
locals
(),
**
sub
))
"""
def
__str__
(
self
):
def
__str__
(
self
):
return
self
.
__class__
.
__name__
return
self
.
__class__
.
__name__
...
@@ -1302,7 +1307,8 @@ class MulSDCSR(_NoPythonCOp):
...
@@ -1302,7 +1307,8 @@ class MulSDCSR(_NoPythonCOp):
if
node
.
inputs
[
3
]
.
type
.
dtype
in
(
"complex64"
,
"complex128"
):
if
node
.
inputs
[
3
]
.
type
.
dtype
in
(
"complex64"
,
"complex128"
):
raise
NotImplementedError
(
"Complex types are not supported for b"
)
raise
NotImplementedError
(
"Complex types are not supported for b"
)
return
"""
fail
=
sub
[
"fail"
]
return
f
"""
if (PyArray_NDIM({_b}) != 2) {{
if (PyArray_NDIM({_b}) != 2) {{
PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 2");
PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 2");
{fail};}}
{fail};}}
...
@@ -1368,7 +1374,7 @@ class MulSDCSR(_NoPythonCOp):
...
@@ -1368,7 +1374,7 @@ class MulSDCSR(_NoPythonCOp):
}}
}}
}}
}}
"""
.
format
(
**
dict
(
locals
(),
**
sub
))
"""
def
__str__
(
self
):
def
__str__
(
self
):
return
self
.
__class__
.
__name__
return
self
.
__class__
.
__name__
...
@@ -1491,7 +1497,8 @@ class MulSVCSR(_NoPythonCOp):
...
@@ -1491,7 +1497,8 @@ class MulSVCSR(_NoPythonCOp):
if
node
.
inputs
[
3
]
.
type
.
dtype
in
(
"complex64"
,
"complex128"
):
if
node
.
inputs
[
3
]
.
type
.
dtype
in
(
"complex64"
,
"complex128"
):
raise
NotImplementedError
(
"Complex types are not supported for b"
)
raise
NotImplementedError
(
"Complex types are not supported for b"
)
return
"""
fail
=
sub
[
"fail"
]
return
f
"""
if (PyArray_NDIM({_b}) != 1) {{
if (PyArray_NDIM({_b}) != 1) {{
PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 1");
PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 1");
{fail};
{fail};
...
@@ -1553,7 +1560,7 @@ class MulSVCSR(_NoPythonCOp):
...
@@ -1553,7 +1560,7 @@ class MulSVCSR(_NoPythonCOp):
}}
}}
}}
}}
"""
.
format
(
**
dict
(
locals
(),
**
sub
))
"""
def
__str__
(
self
):
def
__str__
(
self
):
return
self
.
__class__
.
__name__
return
self
.
__class__
.
__name__
...
@@ -1663,7 +1670,8 @@ class StructuredAddSVCSR(_NoPythonCOp):
...
@@ -1663,7 +1670,8 @@ class StructuredAddSVCSR(_NoPythonCOp):
if
node
.
inputs
[
3
]
.
type
.
dtype
in
(
"complex64"
,
"complex128"
):
if
node
.
inputs
[
3
]
.
type
.
dtype
in
(
"complex64"
,
"complex128"
):
raise
NotImplementedError
(
"Complex types are not supported for b"
)
raise
NotImplementedError
(
"Complex types are not supported for b"
)
return
"""
fail
=
sub
[
"fail"
]
return
f
"""
if (PyArray_NDIM({_b}) != 1) {{
if (PyArray_NDIM({_b}) != 1) {{
PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 1");
PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 1");
{fail};
{fail};
...
@@ -1732,7 +1740,7 @@ class StructuredAddSVCSR(_NoPythonCOp):
...
@@ -1732,7 +1740,7 @@ class StructuredAddSVCSR(_NoPythonCOp):
}}
}}
}}
}}
"""
.
format
(
**
dict
(
locals
(),
**
sub
))
"""
def
__str__
(
self
):
def
__str__
(
self
):
return
self
.
__class__
.
__name__
return
self
.
__class__
.
__name__
...
@@ -1904,7 +1912,8 @@ class SamplingDotCSR(_NoPythonCOp):
...
@@ -1904,7 +1912,8 @@ class SamplingDotCSR(_NoPythonCOp):
typenum_zi
=
TensorType
(
node
.
outputs
[
1
]
.
dtype
,
[])
.
dtype_specs
()[
2
]
typenum_zi
=
TensorType
(
node
.
outputs
[
1
]
.
dtype
,
[])
.
dtype_specs
()[
2
]
typenum_zp
=
TensorType
(
node
.
outputs
[
2
]
.
dtype
,
[])
.
dtype_specs
()[
2
]
typenum_zp
=
TensorType
(
node
.
outputs
[
2
]
.
dtype
,
[])
.
dtype_specs
()[
2
]
rval
=
"""
fail
=
sub
[
"fail"
]
rval
=
f
"""
if (PyArray_NDIM({x}) != 2) {{
if (PyArray_NDIM({x}) != 2) {{
PyErr_SetString(PyExc_NotImplementedError, "rank(x) != 2"); {fail};}}
PyErr_SetString(PyExc_NotImplementedError, "rank(x) != 2"); {fail};}}
if (PyArray_NDIM({y}) != 2) {{
if (PyArray_NDIM({y}) != 2) {{
...
@@ -2024,7 +2033,7 @@ PyErr_SetString(PyExc_NotImplementedError, "rank(y) != 2"); {fail};}}
...
@@ -2024,7 +2033,7 @@ PyErr_SetString(PyExc_NotImplementedError, "rank(y) != 2"); {fail};}}
}}
}}
}}
}}
}}
}}
"""
.
format
(
**
dict
(
locals
(),
**
sub
))
"""
return
rval
return
rval
...
...
pytensor/tensor/basic.py
浏览文件 @
55b2f4fa
...
@@ -597,12 +597,12 @@ class TensorFromScalar(COp):
...
@@ -597,12 +597,12 @@ class TensorFromScalar(COp):
(
z
,)
=
outputs
(
z
,)
=
outputs
fail
=
sub
[
"fail"
]
fail
=
sub
[
"fail"
]
return
"""
return
f
"""
{z} = (PyArrayObject*)PyArray_FromScalar(py_{x}, NULL);
{z} = (PyArrayObject*)PyArray_FromScalar(py_{x}, NULL);
if({z} == NULL){{
if({z} == NULL){{
{fail};
{fail};
}}
}}
"""
.
format
(
**
locals
())
"""
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
return
(
2
,)
return
(
2
,)
...
@@ -646,10 +646,9 @@ class ScalarFromTensor(COp):
...
@@ -646,10 +646,9 @@ class ScalarFromTensor(COp):
def
c_code
(
self
,
node
,
name
,
inputs
,
outputs
,
sub
):
def
c_code
(
self
,
node
,
name
,
inputs
,
outputs
,
sub
):
(
x
,)
=
inputs
(
x
,)
=
inputs
(
z
,)
=
outputs
(
z
,)
=
outputs
fail
=
sub
[
"fail"
]
return
f
"""
return
"""
{z} = ((dtype_{x}*)(PyArray_DATA({x})))[0];
{z} = ((dtype_{x}*)(PyArray_DATA({x})))[0];
"""
.
format
(
**
locals
())
"""
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
return
(
1
,)
return
(
1
,)
...
@@ -1836,18 +1835,18 @@ class MakeVector(COp):
...
@@ -1836,18 +1835,18 @@ class MakeVector(COp):
assert
self
.
dtype
==
node
.
inputs
[
0
]
.
dtype
assert
self
.
dtype
==
node
.
inputs
[
0
]
.
dtype
out_num
=
f
"PyArray_TYPE({inp[0]})"
out_num
=
f
"PyArray_TYPE({inp[0]})"
ret
=
"""
ret
=
f
"""
npy_intp dims[1];
npy_intp dims[1];
dims[0] = {out_shape};
dims[0] = {out_shape};
if(!{out} || PyArray_DIMS({out})[0] != {out_shape}){{
if(!{out} || PyArray_DIMS({out})[0] != {out_shape}){{
Py_XDECREF({out});
Py_XDECREF({out});
{out} = (PyArrayObject*)PyArray_EMPTY(1, dims, {out_num}, 0);
{out} = (PyArrayObject*)PyArray_EMPTY(1, dims, {out_num}, 0);
}}
}}
"""
.
format
(
**
locals
())
"""
for
idx
,
i
in
enumerate
(
inp
):
for
idx
,
i
in
enumerate
(
inp
):
ret
+=
"""
ret
+=
f
"""
*(({out_dtype} *)PyArray_GETPTR1({out}, {idx})) = *(({out_dtype} *) PyArray_DATA({i}));
*(({out_dtype} *)PyArray_GETPTR1({out}, {idx})) = *(({out_dtype} *) PyArray_DATA({i}));
"""
.
format
(
**
locals
())
"""
return
ret
return
ret
def
infer_shape
(
self
,
fgraph
,
node
,
ishapes
):
def
infer_shape
(
self
,
fgraph
,
node
,
ishapes
):
...
@@ -2225,7 +2224,7 @@ class Split(COp):
...
@@ -2225,7 +2224,7 @@ class Split(COp):
splits_dtype
=
node
.
inputs
[
2
]
.
type
.
dtype_specs
()[
1
]
splits_dtype
=
node
.
inputs
[
2
]
.
type
.
dtype_specs
()[
1
]
expected_splits_count
=
self
.
len_splits
expected_splits_count
=
self
.
len_splits
return
"""
return
f
"""
int ndim = PyArray_NDIM({x});
int ndim = PyArray_NDIM({x});
int axis = (int)(*({axis_dtype}*)PyArray_GETPTR1({axis}, 0));
int axis = (int)(*({axis_dtype}*)PyArray_GETPTR1({axis}, 0));
int splits_count = PyArray_DIM({splits}, 0);
int splits_count = PyArray_DIM({splits}, 0);
...
@@ -2322,7 +2321,7 @@ class Split(COp):
...
@@ -2322,7 +2321,7 @@ class Split(COp):
}}
}}
free(split_dims);
free(split_dims);
"""
.
format
(
**
locals
())
"""
class
Join
(
COp
):
class
Join
(
COp
):
...
@@ -2373,10 +2372,9 @@ class Join(COp):
...
@@ -2373,10 +2372,9 @@ class Join(COp):
if
self
.
view
==
-
1
:
if
self
.
view
==
-
1
:
return
self
.
__class__
.
__name__
return
self
.
__class__
.
__name__
else
:
else
:
return
"{}{{{}}}"
.
format
(
classname
=
self
.
__class__
.
__name__
self
.
__class__
.
__name__
,
args
=
", "
.
join
(
f
"{p}={getattr(self, p)!r}"
for
p
in
self
.
__props__
)
", "
.
join
(
f
"{p}={getattr(self, p)!r}"
for
p
in
self
.
__props__
),
return
f
"{classname}{{{args}}}"
)
def
__setstate__
(
self
,
d
):
def
__setstate__
(
self
,
d
):
self
.
__dict__
.
update
(
d
)
self
.
__dict__
.
update
(
d
)
...
@@ -2538,7 +2536,7 @@ class Join(COp):
...
@@ -2538,7 +2536,7 @@ class Join(COp):
copy_inputs_to_list
=
"
\n
"
.
join
(
copy_to_list
)
copy_inputs_to_list
=
"
\n
"
.
join
(
copy_to_list
)
n
=
len
(
tens
)
n
=
len
(
tens
)
code
=
"""
code
=
f
"""
int axis = (({adtype} *)PyArray_DATA({axis}))[0];
int axis = (({adtype} *)PyArray_DATA({axis}))[0];
PyObject* list = PyList_New({l});
PyObject* list = PyList_New({l});
{copy_inputs_to_list}
{copy_inputs_to_list}
...
@@ -2570,7 +2568,7 @@ class Join(COp):
...
@@ -2570,7 +2568,7 @@ class Join(COp):
{fail}
{fail}
}}
}}
}}
}}
"""
.
format
(
**
locals
())
"""
return
code
return
code
def
R_op
(
self
,
inputs
,
eval_points
):
def
R_op
(
self
,
inputs
,
eval_points
):
...
@@ -4192,18 +4190,14 @@ class AllocEmpty(COp):
...
@@ -4192,18 +4190,14 @@ class AllocEmpty(COp):
params
=
sub
[
"params"
]
params
=
sub
[
"params"
]
str
=
f
"npy_intp dims[{nd}];
\n
"
str
=
f
"npy_intp dims[{nd}];
\n
"
for
idx
,
sh
in
enumerate
(
shps
):
for
idx
,
sh
in
enumerate
(
shps
):
str
+=
(
str
+=
f
"dims[{idx}] = ((npy_intp)((dtype_{sh}*) PyArray_DATA({sh}))[0]);
\n
"
"dims[{idx}] ="
"((npy_intp)((dtype_{sh}*)"
" PyArray_DATA({sh}))[0]);
\n
"
.
format
(
**
locals
())
)
# Validate that the output storage exists
# Validate that the output storage exists
str
+=
f
"if({out}==NULL
\n
"
str
+=
f
"if({out}==NULL
\n
"
for
idx
,
sh
in
enumerate
(
shps
):
for
idx
,
sh
in
enumerate
(
shps
):
str
+=
f
"||PyArray_DIMS({out})[{idx}]!=dims[{idx}]"
str
+=
f
"||PyArray_DIMS({out})[{idx}]!=dims[{idx}]"
str
+=
"""){{
str
+=
f
"""){{
/* Reference received to invalid output variable.
/* Reference received to invalid output variable.
Decrease received reference's ref count and allocate new
Decrease received reference's ref count and allocate new
output variable */
output variable */
...
@@ -4218,7 +4212,7 @@ class AllocEmpty(COp):
...
@@ -4218,7 +4212,7 @@ class AllocEmpty(COp):
{fail};
{fail};
}}
}}
}}
}}
"""
.
format
(
**
locals
())
"""
return
str
return
str
def
infer_shape
(
self
,
fgraph
,
node
,
input_shapes
):
def
infer_shape
(
self
,
fgraph
,
node
,
input_shapes
):
...
...
pytensor/tensor/blas.py
浏览文件 @
55b2f4fa
...
@@ -1806,19 +1806,12 @@ class BatchedDot(COp):
...
@@ -1806,19 +1806,12 @@ class BatchedDot(COp):
strides
=
f
"PyArray_STRIDES({var})"
strides
=
f
"PyArray_STRIDES({var})"
if
ndim
==
1
:
if
ndim
==
1
:
return
f
"{strides}[0] == type_size"
return
f
"{strides}[0] == type_size"
return
" && "
.
join
(
ands
=
" && "
.
join
(
[
f
"{strides}[{i}] > 0 && {strides}[{i}]
%
type_size == 0"
" && "
.
join
(
for
i
in
range
(
1
,
ndim
)
f
"{strides}[{i}] > 0 && {strides}[{i}]
%
type_size == 0"
for
i
in
range
(
1
,
ndim
)
),
"({})"
.
format
(
" || "
.
join
(
f
"{strides}[{i}] == type_size"
for
i
in
range
(
1
,
ndim
)
)
),
]
)
)
ors
=
" || "
.
join
(
f
"{strides}[{i}] == type_size"
for
i
in
range
(
1
,
ndim
))
return
f
"{ands} && ({ors})"
x_ndim
,
y_ndim
,
z_ndim
=
(
x_ndim
,
y_ndim
,
z_ndim
=
(
node
.
inputs
[
0
]
.
ndim
,
node
.
inputs
[
0
]
.
ndim
,
...
@@ -1838,7 +1831,7 @@ class BatchedDot(COp):
...
@@ -1838,7 +1831,7 @@ class BatchedDot(COp):
)
)
z_shape
=
", "
.
join
(
z_dims
)
z_shape
=
", "
.
join
(
z_dims
)
z_contiguous
=
contiguous
(
_z
,
z_ndim
)
z_contiguous
=
contiguous
(
_z
,
z_ndim
)
allocate
=
"""
allocate
=
f
"""
if (NULL == {_z} || !({z_shape_correct}) || !({z_contiguous}))
if (NULL == {_z} || !({z_shape_correct}) || !({z_contiguous}))
{{
{{
npy_intp dims[{z_ndim}] = {{{z_shape}}};
npy_intp dims[{z_ndim}] = {{{z_shape}}};
...
@@ -1851,14 +1844,14 @@ class BatchedDot(COp):
...
@@ -1851,14 +1844,14 @@ class BatchedDot(COp):
{fail}
{fail}
}}
}}
}}
}}
"""
.
format
(
**
locals
())
"""
# code to reallocate inputs contiguously if necessary
# code to reallocate inputs contiguously if necessary
contiguate
=
[]
contiguate
=
[]
for
var
,
ndim
in
[(
_x
,
x_ndim
),
(
_y
,
y_ndim
)]:
for
var
,
ndim
in
[(
_x
,
x_ndim
),
(
_y
,
y_ndim
)]:
_contiguous
=
contiguous
(
var
,
ndim
)
_contiguous
=
contiguous
(
var
,
ndim
)
contiguate
.
append
(
contiguate
.
append
(
"""
f
"""
if (!({_contiguous})) {{
if (!({_contiguous})) {{
PyArrayObject * _copy = (PyArrayObject *) PyArray_Copy({var});
PyArrayObject * _copy = (PyArrayObject *) PyArray_Copy({var});
if (!_copy)
if (!_copy)
...
@@ -1866,11 +1859,11 @@ class BatchedDot(COp):
...
@@ -1866,11 +1859,11 @@ class BatchedDot(COp):
Py_XDECREF({var});
Py_XDECREF({var});
{var} = _copy;
{var} = _copy;
}}
}}
"""
.
format
(
**
locals
())
"""
)
)
contiguate
=
"
\n
"
.
join
(
contiguate
)
contiguate
=
"
\n
"
.
join
(
contiguate
)
return
"""
return
f
"""
int type_num = PyArray_DESCR({_x})->type_num;
int type_num = PyArray_DESCR({_x})->type_num;
int type_size = PyArray_DESCR({_x})->elsize; // in bytes
int type_size = PyArray_DESCR({_x})->elsize; // in bytes
...
@@ -1927,7 +1920,7 @@ class BatchedDot(COp):
...
@@ -1927,7 +1920,7 @@ class BatchedDot(COp):
}}
}}
break;
break;
}}
}}
"""
.
format
(
**
locals
())
"""
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
from
pytensor.tensor.blas_headers
import
blas_header_version
from
pytensor.tensor.blas_headers
import
blas_header_version
...
...
pytensor/tensor/blas_c.py
浏览文件 @
55b2f4fa
...
@@ -33,7 +33,7 @@ class BaseBLAS(COp):
...
@@ -33,7 +33,7 @@ class BaseBLAS(COp):
def
ger_c_code
(
A
,
a
,
x
,
y
,
Z
,
fail
,
params
):
def
ger_c_code
(
A
,
a
,
x
,
y
,
Z
,
fail
,
params
):
return
"""
return
f
"""
int elemsize ;
int elemsize ;
...
@@ -309,7 +309,7 @@ def ger_c_code(A, a, x, y, Z, fail, params):
...
@@ -309,7 +309,7 @@ def ger_c_code(A, a, x, y, Z, fail, params):
}}
}}
}}
}}
"""
.
format
(
**
locals
())
"""
class
CGer
(
BaseBLAS
,
Ger
):
class
CGer
(
BaseBLAS
,
Ger
):
...
...
pytensor/tensor/blas_headers.py
浏览文件 @
55b2f4fa
...
@@ -1066,7 +1066,7 @@ def blas_header_version():
...
@@ -1066,7 +1066,7 @@ def blas_header_version():
def
____gemm_code
(
check_ab
,
a_init
,
b_init
):
def
____gemm_code
(
check_ab
,
a_init
,
b_init
):
mod
=
"
%
"
mod
=
"
%
"
return
"""
return
f
"""
const char * error_string = NULL;
const char * error_string = NULL;
int type_num = PyArray_DESCR(_x)->type_num;
int type_num = PyArray_DESCR(_x)->type_num;
...
@@ -1203,4 +1203,4 @@ def ____gemm_code(check_ab, a_init, b_init):
...
@@ -1203,4 +1203,4 @@ def ____gemm_code(check_ab, a_init, b_init):
return -1;
return -1;
/* v 1 */
/* v 1 */
"""
.
format
(
**
locals
())
"""
pytensor/tensor/elemwise.py
浏览文件 @
55b2f4fa
...
@@ -302,10 +302,7 @@ class DimShufflePrinter(Printer):
...
@@ -302,10 +302,7 @@ class DimShufflePrinter(Printer):
return
pstate
.
pprinter
.
process
(
r
)
return
pstate
.
pprinter
.
process
(
r
)
if
list
(
new_order
)
==
list
(
reversed
(
range
(
r
.
type
.
ndim
))):
if
list
(
new_order
)
==
list
(
reversed
(
range
(
r
.
type
.
ndim
))):
return
f
"{pstate.pprinter.process(r)}.T"
return
f
"{pstate.pprinter.process(r)}.T"
return
"DimShuffle{{{}}}({})"
.
format
(
return
f
"DimShuffle{{{', '.join(str(o) for o in new_order)}}}({pstate.pprinter.process(r)})"
", "
.
join
(
map
(
str
,
new_order
)),
pstate
.
pprinter
.
process
(
r
),
)
def
process
(
self
,
r
,
pstate
):
def
process
(
self
,
r
,
pstate
):
if
r
.
owner
is
None
:
if
r
.
owner
is
None
:
...
@@ -929,13 +926,13 @@ class Elemwise(OpenMPOp):
...
@@ -929,13 +926,13 @@ class Elemwise(OpenMPOp):
# We make the output point to the corresponding input and
# We make the output point to the corresponding input and
# decrease the reference of whatever the output contained
# decrease the reference of whatever the output contained
# prior to this
# prior to this
alloc
+=
"""
alloc
+=
f
"""
if ({oname}) {{
if ({oname}) {{
Py_XDECREF({oname});
Py_XDECREF({oname});
}}
}}
{oname} = {iname};
{oname} = {iname};
Py_XINCREF({oname});
Py_XINCREF({oname});
"""
.
format
(
**
locals
())
"""
# We alias the scalar variables
# We alias the scalar variables
defines
+=
f
"#define {oname}_i {iname}_i
\n
"
defines
+=
f
"#define {oname}_i {iname}_i
\n
"
undefs
+=
f
"#undef {oname}_i
\n
"
undefs
+=
f
"#undef {oname}_i
\n
"
...
@@ -958,13 +955,13 @@ class Elemwise(OpenMPOp):
...
@@ -958,13 +955,13 @@ class Elemwise(OpenMPOp):
[
f
"{s}_i"
for
s
in
onames
],
[
f
"{s}_i"
for
s
in
onames
],
dict
(
sub
,
fail
=
fail
),
dict
(
sub
,
fail
=
fail
),
)
)
code
=
"""
code
=
f
"""
{{
{{
{defines}
{defines}
{task_code}
{task_code}
{undefs}
{undefs}
}}
}}
"""
.
format
(
**
locals
())
"""
loop_orders
=
orders
+
[
list
(
range
(
nnested
))]
*
len
(
real_onames
)
loop_orders
=
orders
+
[
list
(
range
(
nnested
))]
*
len
(
real_onames
)
dtypes
=
idtypes
+
list
(
real_odtypes
)
dtypes
=
idtypes
+
list
(
real_odtypes
)
...
@@ -994,19 +991,17 @@ class Elemwise(OpenMPOp):
...
@@ -994,19 +991,17 @@ class Elemwise(OpenMPOp):
if
index
!=
"x"
:
if
index
!=
"x"
:
preloops
.
setdefault
(
j
,
""
)
preloops
.
setdefault
(
j
,
""
)
preloops
[
j
]
+=
(
preloops
[
j
]
+=
(
"
%
(lv{i})s_iter = ({dtype}*)"
f
"
%
(lv{i})s_iter = ({dtype}*)(PyArray_DATA(
%
(lv{i})s));
\n
"
"(PyArray_DATA(
%
(lv{i})s));
\n
"
.
format
(
**
locals
())
)
%
sub
)
%
sub
break
break
else
:
# all broadcastable
else
:
# all broadcastable
preloops
.
setdefault
(
0
,
""
)
preloops
.
setdefault
(
0
,
""
)
preloops
[
0
]
+=
(
preloops
[
0
]
+=
(
"
%
(lv{i})s_iter = ({dtype}*)"
f
"
%
(lv{i})s_iter = ({dtype}*)(PyArray_DATA(
%
(lv{i})s));
\n
"
"(PyArray_DATA(
%
(lv{i})s));
\n
"
.
format
(
**
locals
())
)
%
sub
)
%
sub
init_array
=
preloops
.
get
(
0
,
" "
)
init_array
=
preloops
.
get
(
0
,
" "
)
loop
=
"""
loop
=
f
"""
{{
{{
{defines}
{defines}
{init_array}
{init_array}
...
@@ -1014,7 +1009,7 @@ class Elemwise(OpenMPOp):
...
@@ -1014,7 +1009,7 @@ class Elemwise(OpenMPOp):
{task_code}
{task_code}
{undefs}
{undefs}
}}
}}
"""
.
format
(
**
locals
())
"""
else
:
else
:
loop
=
cgen
.
make_loop
(
loop
=
cgen
.
make_loop
(
loop_orders
=
loop_orders
,
loop_orders
=
loop_orders
,
...
@@ -1074,25 +1069,25 @@ class Elemwise(OpenMPOp):
...
@@ -1074,25 +1069,25 @@ class Elemwise(OpenMPOp):
index
=
""
index
=
""
for
x
,
var
in
zip
(
inames
+
onames
,
inputs
+
node
.
outputs
):
for
x
,
var
in
zip
(
inames
+
onames
,
inputs
+
node
.
outputs
):
if
not
all
(
s
==
1
for
s
in
var
.
type
.
shape
):
if
not
all
(
s
==
1
for
s
in
var
.
type
.
shape
):
contig
+=
"""
contig
+=
f
"""
dtype_{x} * {x}_ptr = (dtype_{x}*) PyArray_DATA({x});
dtype_{x} * {x}_ptr = (dtype_{x}*) PyArray_DATA({x});
"""
.
format
(
**
locals
())
"""
index
+=
"""
index
+=
f
"""
dtype_{x}& {x}_i = {x}_ptr[i];
dtype_{x}& {x}_i = {x}_ptr[i];
"""
.
format
(
**
locals
())
"""
else
:
else
:
contig
+=
"""
contig
+=
f
"""
dtype_{x}& {x}_i = ((dtype_{x}*) PyArray_DATA({x}))[0];
dtype_{x}& {x}_i = ((dtype_{x}*) PyArray_DATA({x}))[0];
"""
.
format
(
**
locals
())
"""
if
self
.
openmp
:
if
self
.
openmp
:
contig
+=
f
"""#pragma omp parallel for if(n>={int(config.openmp_elemwise_minsize)})
contig
+=
f
"""#pragma omp parallel for if(n>={int(config.openmp_elemwise_minsize)})
"""
"""
contig
+=
"""
contig
+=
f
"""
for(int i=0; i<n; i++){{
for(int i=0; i<n; i++){{
{index}
{index}
{task_code};
{task_code};
}}
}}
"""
.
format
(
**
locals
())
"""
if
contig
is
not
None
:
if
contig
is
not
None
:
z
=
list
(
zip
(
inames
+
onames
,
inputs
+
node
.
outputs
))
z
=
list
(
zip
(
inames
+
onames
,
inputs
+
node
.
outputs
))
all_broadcastable
=
all
(
s
==
1
for
s
in
var
.
type
.
shape
)
all_broadcastable
=
all
(
s
==
1
for
s
in
var
.
type
.
shape
)
...
@@ -1106,13 +1101,13 @@ class Elemwise(OpenMPOp):
...
@@ -1106,13 +1101,13 @@ class Elemwise(OpenMPOp):
for
arr
,
var
in
z
for
arr
,
var
in
z
if
not
all_broadcastable
if
not
all_broadcastable
)
)
loop
=
"""
loop
=
f
"""
if(({cond1}) || ({cond2})){{
if(({cond1}) || ({cond2})){{
{contig}
{contig}
}}else{{
}}else{{
{loop}
{loop}
}}
}}
"""
.
format
(
**
locals
())
"""
return
decl
,
checks
,
alloc
,
loop
,
""
return
decl
,
checks
,
alloc
,
loop
,
""
def
c_code
(
self
,
node
,
nodename
,
inames
,
onames
,
sub
):
def
c_code
(
self
,
node
,
nodename
,
inames
,
onames
,
sub
):
...
...
pytensor/tensor/elemwise_cgen.py
浏览文件 @
55b2f4fa
...
@@ -170,12 +170,14 @@ def make_alloc(loop_orders, dtype, sub, fortran="0"):
...
@@ -170,12 +170,14 @@ def make_alloc(loop_orders, dtype, sub, fortran="0"):
type
=
type
.
replace
(
"PYTENSOR_COMPLEX"
,
"NPY_COMPLEX"
)
type
=
type
.
replace
(
"PYTENSOR_COMPLEX"
,
"NPY_COMPLEX"
)
nd
=
len
(
loop_orders
[
0
])
nd
=
len
(
loop_orders
[
0
])
init_dims
=
compute_output_dims_lengths
(
"dims"
,
loop_orders
,
sub
)
init_dims
=
compute_output_dims_lengths
(
"dims"
,
loop_orders
,
sub
)
olv
=
sub
[
"olv"
]
fail
=
sub
[
"fail"
]
# TODO: it would be interesting to allocate the output in such a
# TODO: it would be interesting to allocate the output in such a
# way that its contiguous dimensions match one of the input's
# way that its contiguous dimensions match one of the input's
# contiguous dimensions, or the dimension with the smallest
# contiguous dimensions, or the dimension with the smallest
# stride. Right now, it is allocated to be C_CONTIGUOUS.
# stride. Right now, it is allocated to be C_CONTIGUOUS.
return
"""
return
f
"""
{{
{{
npy_intp dims[{nd}];
npy_intp dims[{nd}];
//npy_intp* dims = (npy_intp*)malloc({nd} * sizeof(npy_intp));
//npy_intp* dims = (npy_intp*)malloc({nd} * sizeof(npy_intp));
...
@@ -203,7 +205,7 @@ def make_alloc(loop_orders, dtype, sub, fortran="0"):
...
@@ -203,7 +205,7 @@ def make_alloc(loop_orders, dtype, sub, fortran="0"):
{fail}
{fail}
}}
}}
}}
}}
"""
.
format
(
**
dict
(
locals
(),
**
sub
))
"""
def
make_loop
(
loop_orders
,
dtypes
,
loop_tasks
,
sub
,
openmp
=
None
):
def
make_loop
(
loop_orders
,
dtypes
,
loop_tasks
,
sub
,
openmp
=
None
):
...
...
pytensor/tensor/extra_ops.py
浏览文件 @
55b2f4fa
...
@@ -68,7 +68,7 @@ class CpuContiguous(COp):
...
@@ -68,7 +68,7 @@ class CpuContiguous(COp):
def
c_code
(
self
,
node
,
name
,
inames
,
onames
,
sub
):
def
c_code
(
self
,
node
,
name
,
inames
,
onames
,
sub
):
(
x
,)
=
inames
(
x
,)
=
inames
(
y
,)
=
onames
(
y
,)
=
onames
code
=
"""
code
=
f
"""
if (!PyArray_CHKFLAGS({x}, NPY_ARRAY_C_CONTIGUOUS)){{
if (!PyArray_CHKFLAGS({x}, NPY_ARRAY_C_CONTIGUOUS)){{
// check to see if output is contiguous first
// check to see if output is contiguous first
if ({y} != NULL &&
if ({y} != NULL &&
...
@@ -86,7 +86,7 @@ class CpuContiguous(COp):
...
@@ -86,7 +86,7 @@ class CpuContiguous(COp):
Py_XDECREF({y});
Py_XDECREF({y});
{y} = {x};
{y} = {x};
}}
}}
"""
.
format
(
**
locals
())
"""
return
code
return
code
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
...
@@ -161,13 +161,13 @@ class SearchsortedOp(COp):
...
@@ -161,13 +161,13 @@ class SearchsortedOp(COp):
def
c_init_code_struct
(
self
,
node
,
name
,
sub
):
def
c_init_code_struct
(
self
,
node
,
name
,
sub
):
side
=
sub
[
"params"
]
side
=
sub
[
"params"
]
fail
=
sub
[
"fail"
]
fail
=
sub
[
"fail"
]
return
"""
return
f
"""
PyObject* tmp_{name} = PyUnicode_FromString("right");
PyObject* tmp_{name} = PyUnicode_FromString("right");
if (tmp_{name} == NULL)
if (tmp_{name} == NULL)
{fail};
{fail};
right_{name} = PyUnicode_Compare({side}, tmp_{name});
right_{name} = PyUnicode_Compare({side}, tmp_{name});
Py_DECREF(tmp_{name});
Py_DECREF(tmp_{name});
"""
.
format
(
**
locals
())
"""
def
c_code
(
self
,
node
,
name
,
inames
,
onames
,
sub
):
def
c_code
(
self
,
node
,
name
,
inames
,
onames
,
sub
):
sorter
=
None
sorter
=
None
...
@@ -180,7 +180,7 @@ class SearchsortedOp(COp):
...
@@ -180,7 +180,7 @@ class SearchsortedOp(COp):
(
z
,)
=
onames
(
z
,)
=
onames
fail
=
sub
[
"fail"
]
fail
=
sub
[
"fail"
]
return
"""
return
f
"""
Py_XDECREF({z});
Py_XDECREF({z});
{z} = (PyArrayObject*) PyArray_SearchSorted({x}, (PyObject*) {v},
{z} = (PyArrayObject*) PyArray_SearchSorted({x}, (PyObject*) {v},
right_{name} ? NPY_SEARCHLEFT : NPY_SEARCHRIGHT, (PyObject*) {sorter});
right_{name} ? NPY_SEARCHLEFT : NPY_SEARCHRIGHT, (PyObject*) {sorter});
...
@@ -191,7 +191,7 @@ class SearchsortedOp(COp):
...
@@ -191,7 +191,7 @@ class SearchsortedOp(COp):
Py_XDECREF({z});
Py_XDECREF({z});
{z} = (PyArrayObject*) tmp;
{z} = (PyArrayObject*) tmp;
}}
}}
"""
.
format
(
**
locals
())
"""
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
return
(
2
,)
return
(
2
,)
...
@@ -348,11 +348,10 @@ class CumOp(COp):
...
@@ -348,11 +348,10 @@ class CumOp(COp):
def
c_code
(
self
,
node
,
name
,
inames
,
onames
,
sub
):
def
c_code
(
self
,
node
,
name
,
inames
,
onames
,
sub
):
(
x
,)
=
inames
(
x
,)
=
inames
(
z
,)
=
onames
(
z
,)
=
onames
axis
=
self
.
axis
fail
=
sub
[
"fail"
]
fail
=
sub
[
"fail"
]
params
=
sub
[
"params"
]
params
=
sub
[
"params"
]
code
=
"""
code
=
f
"""
int axis = {params}->c_axis;
int axis = {params}->c_axis;
if (axis == 0 && PyArray_NDIM({x}) == 1)
if (axis == 0 && PyArray_NDIM({x}) == 1)
axis = NPY_MAXDIMS;
axis = NPY_MAXDIMS;
...
@@ -389,7 +388,7 @@ class CumOp(COp):
...
@@ -389,7 +388,7 @@ class CumOp(COp):
// Because PyArray_CumSum/CumProd returns a newly created reference on t.
// Because PyArray_CumSum/CumProd returns a newly created reference on t.
Py_XDECREF(t);
Py_XDECREF(t);
}}
}}
"""
.
format
(
**
locals
())
"""
return
code
return
code
...
...
pytensor/tensor/math.py
浏览文件 @
55b2f4fa
...
@@ -215,14 +215,14 @@ class Argmax(COp):
...
@@ -215,14 +215,14 @@ class Argmax(COp):
if
len
(
self
.
axis
)
!=
1
:
if
len
(
self
.
axis
)
!=
1
:
raise
NotImplementedError
()
raise
NotImplementedError
()
# params is only used here for now
# params is only used here for now
axis_code
=
"""
axis_code
=
f
"""
axis = {params}->c_axis;
axis = {params}->c_axis;
if(axis > PyArray_NDIM({x})-1 || axis < -PyArray_NDIM({x})){{
if(axis > PyArray_NDIM({x})-1 || axis < -PyArray_NDIM({x})){{
PyErr_SetString(PyExc_ValueError,
PyErr_SetString(PyExc_ValueError,
"Argmax, bad axis argument");
"Argmax, bad axis argument");
{fail}
{fail}
}}
}}
"""
.
format
(
**
locals
())
"""
ret
=
"""
ret
=
"""
int axis;
int axis;
...
@@ -1314,7 +1314,8 @@ class Mean(FixedOpCAReduce):
...
@@ -1314,7 +1314,8 @@ class Mean(FixedOpCAReduce):
def
__str__
(
self
):
def
__str__
(
self
):
if
self
.
axis
is
not
None
:
if
self
.
axis
is
not
None
:
return
"Mean{{{}}}"
.
format
(
", "
.
join
(
str
(
x
)
for
x
in
self
.
axis
))
args
=
", "
.
join
(
str
(
x
)
for
x
in
self
.
axis
)
return
f
"Mean{{{args}}}"
else
:
else
:
return
"Mean"
return
"Mean"
...
...
pytensor/tensor/subtensor.py
浏览文件 @
55b2f4fa
...
@@ -1137,7 +1137,7 @@ class Subtensor(COp):
...
@@ -1137,7 +1137,7 @@ class Subtensor(COp):
"""
"""
rval
+=
"""
rval
+=
f
"""
// One more argument of the view
// One more argument of the view
npy_intp xview_offset = 0;
npy_intp xview_offset = 0;
...
@@ -1264,7 +1264,7 @@ class Subtensor(COp):
...
@@ -1264,7 +1264,7 @@ class Subtensor(COp):
inner_ii += 1;
inner_ii += 1;
outer_ii += 1;
outer_ii += 1;
}}
}}
"""
.
format
(
**
locals
())
"""
# print rval
# print rval
return
rval
return
rval
...
@@ -1284,19 +1284,19 @@ class Subtensor(COp):
...
@@ -1284,19 +1284,19 @@ class Subtensor(COp):
decl
=
"PyArrayObject * xview = NULL;"
decl
=
"PyArrayObject * xview = NULL;"
checkNDim
=
"""
checkNDim
=
f
"""
if (PyArray_NDIM({x}) != {ndim}){{
if (PyArray_NDIM({x}) != {ndim}){{
PyErr_SetString(PyExc_ValueError,
PyErr_SetString(PyExc_ValueError,
"Expected {ndim} dimensions input"
"Expected {ndim} dimensions input"
);
);
{fail}
{fail}
}}
}}
"""
.
format
(
**
locals
())
"""
get_xview
=
self
.
helper_c_code
(
get_xview
=
self
.
helper_c_code
(
node
,
name
,
inputs
,
outputs
,
sub
,
self
.
idx_list
,
view_ndim
node
,
name
,
inputs
,
outputs
,
sub
,
self
.
idx_list
,
view_ndim
)
)
build_view
=
"""
build_view
=
f
"""
//TODO: give this Op a second output so that this view can be cached
//TODO: give this Op a second output so that this view can be cached
//TODO: alternatively, fix the memory leak on failure
//TODO: alternatively, fix the memory leak on failure
Py_INCREF(PyArray_DESCR({x}));
Py_INCREF(PyArray_DESCR({x}));
...
@@ -1314,7 +1314,7 @@ class Subtensor(COp):
...
@@ -1314,7 +1314,7 @@ class Subtensor(COp):
{{
{{
{fail};
{fail};
}}
}}
"""
.
format
(
**
locals
())
"""
finish_view
=
f
"""
finish_view
=
f
"""
Py_XDECREF({z});
Py_XDECREF({z});
...
@@ -1761,7 +1761,7 @@ class IncSubtensor(COp):
...
@@ -1761,7 +1761,7 @@ class IncSubtensor(COp):
copy_of_x
=
self
.
copy_of_x
(
x
)
copy_of_x
=
self
.
copy_of_x
(
x
)
copy_input_if_necessary
=
"""
copy_input_if_necessary
=
f
"""
if ({inplace})
if ({inplace})
{{
{{
if ({x} != {z})
if ({x} != {z})
...
@@ -1780,7 +1780,7 @@ class IncSubtensor(COp):
...
@@ -1780,7 +1780,7 @@ class IncSubtensor(COp):
{fail}
{fail}
}}
}}
}}
}}
"""
.
format
(
**
locals
())
"""
# get info needed to make zview: a view of %(z)s
# get info needed to make zview: a view of %(z)s
helper_args
=
self
.
get_helper_c_code_args
()
helper_args
=
self
.
get_helper_c_code_args
()
...
@@ -1799,7 +1799,7 @@ class IncSubtensor(COp):
...
@@ -1799,7 +1799,7 @@ class IncSubtensor(COp):
# Make a view on the output, as we will write into it.
# Make a view on the output, as we will write into it.
alloc_zview
=
self
.
make_view_array
(
z
,
view_ndim
)
alloc_zview
=
self
.
make_view_array
(
z
,
view_ndim
)
build_view
=
"""
build_view
=
f
"""
//TODO: give this Op a second output so that this view can be cached
//TODO: give this Op a second output so that this view can be cached
//TODO: alternatively, fix the memory leak on failure
//TODO: alternatively, fix the memory leak on failure
{alloc_zview};
{alloc_zview};
...
@@ -1807,13 +1807,13 @@ class IncSubtensor(COp):
...
@@ -1807,13 +1807,13 @@ class IncSubtensor(COp):
{{
{{
{fail};
{fail};
}}
}}
"""
.
format
(
**
locals
())
"""
copy_into
=
self
.
copy_into
(
"zview"
,
y
)
copy_into
=
self
.
copy_into
(
"zview"
,
y
)
add_to_zview
=
self
.
add_to_zview
(
name
,
y
,
fail
)
add_to_zview
=
self
.
add_to_zview
(
name
,
y
,
fail
)
make_modification
=
"""
make_modification
=
f
"""
if ({op_is_set})
if ({op_is_set})
{{
{{
if ({copy_into}) // does broadcasting
if ({copy_into}) // does broadcasting
...
@@ -1826,7 +1826,7 @@ class IncSubtensor(COp):
...
@@ -1826,7 +1826,7 @@ class IncSubtensor(COp):
{{
{{
{add_to_zview}
{add_to_zview}
}}
}}
"""
.
format
(
**
locals
())
"""
return
(
return
(
self
.
decl_view
()
self
.
decl_view
()
+
copy_input_if_necessary
+
copy_input_if_necessary
...
@@ -1896,7 +1896,7 @@ class IncSubtensor(COp):
...
@@ -1896,7 +1896,7 @@ class IncSubtensor(COp):
"""
"""
return
"""Py_INCREF(PyArray_DESCR({x}));
return
f
"""Py_INCREF(PyArray_DESCR({x}));
zview = (PyArrayObject*)PyArray_NewFromDescr(
zview = (PyArrayObject*)PyArray_NewFromDescr(
&PyArray_Type,
&PyArray_Type,
PyArray_DESCR({x}),
PyArray_DESCR({x}),
...
@@ -1906,7 +1906,7 @@ class IncSubtensor(COp):
...
@@ -1906,7 +1906,7 @@ class IncSubtensor(COp):
PyArray_BYTES({x}) + xview_offset, //PyArray_DATA({x}),
PyArray_BYTES({x}) + xview_offset, //PyArray_DATA({x}),
PyArray_FLAGS({x}),
PyArray_FLAGS({x}),
NULL);
NULL);
"""
.
format
(
**
locals
())
"""
def
get_helper_c_code_args
(
self
):
def
get_helper_c_code_args
(
self
):
"""
"""
...
@@ -1939,7 +1939,7 @@ class IncSubtensor(COp):
...
@@ -1939,7 +1939,7 @@ class IncSubtensor(COp):
"""
"""
return
"""
return
f
"""
PyArrayObject * add_rval = (PyArrayObject*)PyNumber_InPlaceAdd(
PyArrayObject * add_rval = (PyArrayObject*)PyNumber_InPlaceAdd(
(PyObject*)zview, py_{x});
(PyObject*)zview, py_{x});
if (add_rval)
if (add_rval)
...
@@ -1952,7 +1952,7 @@ class IncSubtensor(COp):
...
@@ -1952,7 +1952,7 @@ class IncSubtensor(COp):
{{
{{
Py_DECREF(zview);
Py_DECREF(zview);
{fail};
{fail};
}}"""
.
format
(
**
locals
())
}}"""
def
infer_shape
(
self
,
fgraph
,
node
,
shapes
):
def
infer_shape
(
self
,
fgraph
,
node
,
shapes
):
return
[
shapes
[
0
]]
return
[
shapes
[
0
]]
...
@@ -2162,7 +2162,7 @@ class AdvancedSubtensor1(COp):
...
@@ -2162,7 +2162,7 @@ class AdvancedSubtensor1(COp):
a_name
,
i_name
=
input_names
[
0
],
input_names
[
1
]
a_name
,
i_name
=
input_names
[
0
],
input_names
[
1
]
output_name
=
output_names
[
0
]
output_name
=
output_names
[
0
]
fail
=
sub
[
"fail"
]
fail
=
sub
[
"fail"
]
return
"""
return
f
"""
PyArrayObject *indices;
PyArrayObject *indices;
int i_type = PyArray_TYPE({i_name});
int i_type = PyArray_TYPE({i_name});
if (i_type != NPY_INTP) {{
if (i_type != NPY_INTP) {{
...
@@ -2237,7 +2237,7 @@ class AdvancedSubtensor1(COp):
...
@@ -2237,7 +2237,7 @@ class AdvancedSubtensor1(COp):
{a_name}, (PyObject*)indices, 0, {output_name}, NPY_RAISE);
{a_name}, (PyObject*)indices, 0, {output_name}, NPY_RAISE);
Py_DECREF(indices);
Py_DECREF(indices);
if ({output_name} == NULL) {fail};
if ({output_name} == NULL) {fail};
"""
.
format
(
**
locals
())
"""
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
return
(
0
,
1
,
2
)
return
(
0
,
1
,
2
)
...
@@ -2523,8 +2523,10 @@ class AdvancedIncSubtensor1(COp):
...
@@ -2523,8 +2523,10 @@ class AdvancedIncSubtensor1(COp):
x
,
y
,
idx
=
input_names
x
,
y
,
idx
=
input_names
out
=
output_names
[
0
]
out
=
output_names
[
0
]
copy_of_x
=
self
.
copy_of_x
(
x
)
copy_of_x
=
self
.
copy_of_x
(
x
)
params
=
sub
[
"params"
]
fail
=
sub
[
"fail"
]
return
"""
return
f
"""
PyObject* rval = NULL;
PyObject* rval = NULL;
if ({params}->inplace)
if ({params}->inplace)
{{
{{
...
@@ -2548,17 +2550,7 @@ class AdvancedIncSubtensor1(COp):
...
@@ -2548,17 +2550,7 @@ class AdvancedIncSubtensor1(COp):
{fail};
{fail};
}}
}}
Py_XDECREF(rval);
Py_XDECREF(rval);
"""
.
format
(
"""
**
dict
(
x
=
x
,
y
=
y
,
idx
=
idx
,
out
=
out
,
copy_of_x
=
copy_of_x
,
params
=
sub
[
"params"
],
fail
=
sub
[
"fail"
],
)
)
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
return
(
8
,)
return
(
8
,)
...
...
pytensor/tensor/type.py
浏览文件 @
55b2f4fa
...
@@ -475,25 +475,28 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
...
@@ -475,25 +475,28 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
def
c_declare
(
self
,
name
,
sub
,
check_input
=
True
):
def
c_declare
(
self
,
name
,
sub
,
check_input
=
True
):
if
check_input
:
if
check_input
:
check
=
"""
dtype
=
self
.
dtype_specs
()[
1
]
check
=
f
"""
typedef {dtype} dtype_{name};
typedef {dtype} dtype_{name};
"""
.
format
(
**
dict
(
sub
,
name
=
name
,
dtype
=
self
.
dtype_specs
()[
1
]))
"""
else
:
else
:
check
=
""
check
=
""
declaration
=
"""
declaration
=
f
"""
PyArrayObject* {name};
PyArrayObject* {name};
"""
.
format
(
**
dict
(
sub
,
name
=
name
,
dtype
=
self
.
dtype_specs
()[
1
]))
"""
return
declaration
+
check
return
declaration
+
check
def
c_init
(
self
,
name
,
sub
):
def
c_init
(
self
,
name
,
sub
):
return
"""
return
f
"""
{name} = NULL;
{name} = NULL;
"""
.
format
(
**
dict
(
sub
,
name
=
name
,
type_num
=
self
.
dtype_specs
()[
2
]))
"""
def
c_extract
(
self
,
name
,
sub
,
check_input
=
True
,
**
kwargs
):
def
c_extract
(
self
,
name
,
sub
,
check_input
=
True
,
**
kwargs
):
if
check_input
:
if
check_input
:
check
=
"""
fail
=
sub
[
"fail"
]
type_num
=
self
.
dtype_specs
()[
2
]
check
=
f
"""
{name} = NULL;
{name} = NULL;
if (py_{name} == Py_None) {{
if (py_{name} == Py_None) {{
// We can either fail here or set {name} to NULL and rely on Ops
// We can either fail here or set {name} to NULL and rely on Ops
...
@@ -541,28 +544,27 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
...
@@ -541,28 +544,27 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
{type_num}, PyArray_TYPE((PyArrayObject*) py_{name}));
{type_num}, PyArray_TYPE((PyArrayObject*) py_{name}));
{fail}
{fail}
}}
}}
"""
.
format
(
**
dict
(
sub
,
name
=
name
,
type_num
=
self
.
dtype_specs
()[
2
]))
"""
else
:
else
:
check
=
""
check
=
""
return
(
return
(
check
check
+
"""
+
f
"""
{name} = (PyArrayObject*)(py_{name});
{name} = (PyArrayObject*)(py_{name});
Py_XINCREF({name});
Py_XINCREF({name});
"""
.
format
(
**
dict
(
sub
,
name
=
name
,
type_num
=
self
.
dtype_specs
()[
2
]))
"""
)
)
def
c_cleanup
(
self
,
name
,
sub
):
def
c_cleanup
(
self
,
name
,
sub
):
return
"""
return
f
"""
if ({name}) {{
if ({name}) {{
Py_XDECREF({name});
Py_XDECREF({name});
}}
}}
"""
.
format
(
**
locals
())
"""
def
c_sync
(
self
,
name
,
sub
):
def
c_sync
(
self
,
name
,
sub
):
fail
=
sub
[
"fail"
]
fail
=
sub
[
"fail"
]
type_num
=
self
.
dtype_specs
()[
2
]
return
f
"""
return
"""
{{Py_XDECREF(py_{name});}}
{{Py_XDECREF(py_{name});}}
if (!{name}) {{
if (!{name}) {{
Py_INCREF(Py_None);
Py_INCREF(Py_None);
...
@@ -597,7 +599,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
...
@@ -597,7 +599,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
);
);
{fail}
{fail}
}}
}}
"""
.
format
(
**
locals
())
"""
def
c_headers
(
self
,
**
kwargs
):
def
c_headers
(
self
,
**
kwargs
):
return
ps
.
get_scalar_type
(
self
.
dtype
)
.
c_headers
(
**
kwargs
)
return
ps
.
get_scalar_type
(
self
.
dtype
)
.
c_headers
(
**
kwargs
)
...
...
pytensor/typed_list/basic.py
浏览文件 @
55b2f4fa
...
@@ -102,13 +102,13 @@ class GetItem(COp):
...
@@ -102,13 +102,13 @@ class GetItem(COp):
x_name
,
index
=
inp
[
0
],
inp
[
1
]
x_name
,
index
=
inp
[
0
],
inp
[
1
]
output_name
=
out
[
0
]
output_name
=
out
[
0
]
fail
=
sub
[
"fail"
]
fail
=
sub
[
"fail"
]
return
"""
return
f
"""
{output_name} = (typeof {output_name}) PyList_GetItem( (PyObject*) {x_name}, *((npy_int64 *) PyArray_DATA({index})));
{output_name} = (typeof {output_name}) PyList_GetItem( (PyObject*) {x_name}, *((npy_int64 *) PyArray_DATA({index})));
if({output_name} == NULL){{
if({output_name} == NULL){{
{fail}
{fail}
}}
}}
Py_INCREF({output_name});
Py_INCREF({output_name});
"""
.
format
(
**
locals
())
"""
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
return
(
1
,)
return
(
1
,)
...
@@ -169,16 +169,16 @@ class Append(COp):
...
@@ -169,16 +169,16 @@ class Append(COp):
output_name
=
out
[
0
]
output_name
=
out
[
0
]
fail
=
sub
[
"fail"
]
fail
=
sub
[
"fail"
]
if
not
self
.
inplace
:
if
not
self
.
inplace
:
init
=
"""
init
=
f
"""
{output_name} = (PyListObject*) PyList_GetSlice((PyObject*) {x_name}, 0, PyList_GET_SIZE((PyObject*) {x_name})) ;
{output_name} = (PyListObject*) PyList_GetSlice((PyObject*) {x_name}, 0, PyList_GET_SIZE((PyObject*) {x_name})) ;
"""
.
format
(
**
locals
())
"""
else
:
else
:
init
=
f
"""
init
=
f
"""
{output_name} = {x_name};
{output_name} = {x_name};
"""
"""
return
(
return
(
init
init
+
"""
+
f
"""
if({output_name}==NULL){{
if({output_name}==NULL){{
{fail}
{fail}
}};
}};
...
@@ -186,7 +186,7 @@ class Append(COp):
...
@@ -186,7 +186,7 @@ class Append(COp):
{fail}
{fail}
}};
}};
Py_INCREF({output_name});
Py_INCREF({output_name});
"""
.
format
(
**
locals
())
"""
)
)
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
...
@@ -249,16 +249,16 @@ class Extend(COp):
...
@@ -249,16 +249,16 @@ class Extend(COp):
output_name
=
out
[
0
]
output_name
=
out
[
0
]
fail
=
sub
[
"fail"
]
fail
=
sub
[
"fail"
]
if
not
self
.
inplace
:
if
not
self
.
inplace
:
init
=
"""
init
=
f
"""
{output_name} = (PyListObject*) PyList_GetSlice((PyObject*) {x_name}, 0, PyList_GET_SIZE((PyObject*) {x_name})) ;
{output_name} = (PyListObject*) PyList_GetSlice((PyObject*) {x_name}, 0, PyList_GET_SIZE((PyObject*) {x_name})) ;
"""
.
format
(
**
locals
())
"""
else
:
else
:
init
=
f
"""
init
=
f
"""
{output_name} = {x_name};
{output_name} = {x_name};
"""
"""
return
(
return
(
init
init
+
"""
+
f
"""
int i =0;
int i =0;
int length = PyList_GET_SIZE((PyObject*) {toAppend});
int length = PyList_GET_SIZE((PyObject*) {toAppend});
if({output_name}==NULL){{
if({output_name}==NULL){{
...
@@ -270,7 +270,7 @@ class Extend(COp):
...
@@ -270,7 +270,7 @@ class Extend(COp):
}};
}};
}}
}}
Py_INCREF({output_name});
Py_INCREF({output_name});
"""
.
format
(
**
locals
())
"""
)
)
def
c_code_cache_version_
(
self
):
def
c_code_cache_version_
(
self
):
...
@@ -337,16 +337,16 @@ class Insert(COp):
...
@@ -337,16 +337,16 @@ class Insert(COp):
output_name
=
out
[
0
]
output_name
=
out
[
0
]
fail
=
sub
[
"fail"
]
fail
=
sub
[
"fail"
]
if
not
self
.
inplace
:
if
not
self
.
inplace
:
init
=
"""
init
=
f
"""
{output_name} = (PyListObject*) PyList_GetSlice((PyObject*) {x_name}, 0, PyList_GET_SIZE((PyObject*) {x_name})) ;
{output_name} = (PyListObject*) PyList_GetSlice((PyObject*) {x_name}, 0, PyList_GET_SIZE((PyObject*) {x_name})) ;
"""
.
format
(
**
locals
())
"""
else
:
else
:
init
=
f
"""
init
=
f
"""
{output_name} = {x_name};
{output_name} = {x_name};
"""
"""
return
(
return
(
init
init
+
"""
+
f
"""
if({output_name}==NULL){{
if({output_name}==NULL){{
{fail}
{fail}
}};
}};
...
@@ -354,7 +354,7 @@ class Insert(COp):
...
@@ -354,7 +354,7 @@ class Insert(COp):
{fail}
{fail}
}};
}};
Py_INCREF({output_name});
Py_INCREF({output_name});
"""
.
format
(
**
locals
())
"""
)
)
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
...
@@ -465,16 +465,16 @@ class Reverse(COp):
...
@@ -465,16 +465,16 @@ class Reverse(COp):
output_name
=
out
[
0
]
output_name
=
out
[
0
]
fail
=
sub
[
"fail"
]
fail
=
sub
[
"fail"
]
if
not
self
.
inplace
:
if
not
self
.
inplace
:
init
=
"""
init
=
f
"""
{output_name} = (PyListObject*) PyList_GetSlice((PyObject*) {x_name}, 0, PyList_GET_SIZE((PyObject*) {x_name})) ;
{output_name} = (PyListObject*) PyList_GetSlice((PyObject*) {x_name}, 0, PyList_GET_SIZE((PyObject*) {x_name})) ;
"""
.
format
(
**
locals
())
"""
else
:
else
:
init
=
f
"""
init
=
f
"""
{output_name} = {x_name};
{output_name} = {x_name};
"""
"""
return
(
return
(
init
init
+
"""
+
f
"""
if({output_name}==NULL){{
if({output_name}==NULL){{
{fail}
{fail}
}};
}};
...
@@ -482,7 +482,7 @@ class Reverse(COp):
...
@@ -482,7 +482,7 @@ class Reverse(COp):
{fail}
{fail}
}};
}};
Py_INCREF({output_name});
Py_INCREF({output_name});
"""
.
format
(
**
locals
())
"""
)
)
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
...
@@ -595,13 +595,12 @@ class Length(COp):
...
@@ -595,13 +595,12 @@ class Length(COp):
def
c_code
(
self
,
node
,
name
,
inp
,
out
,
sub
):
def
c_code
(
self
,
node
,
name
,
inp
,
out
,
sub
):
x_name
=
inp
[
0
]
x_name
=
inp
[
0
]
output_name
=
out
[
0
]
output_name
=
out
[
0
]
fail
=
sub
[
"fail"
]
return
f
"""
return
"""
if(!{output_name})
if(!{output_name})
{output_name}=(PyArrayObject*)PyArray_EMPTY(0, NULL, NPY_INT64, 0);
{output_name}=(PyArrayObject*)PyArray_EMPTY(0, NULL, NPY_INT64, 0);
((npy_int64*)PyArray_DATA({output_name}))[0]=PyList_Size((PyObject*){x_name});
((npy_int64*)PyArray_DATA({output_name}))[0]=PyList_Size((PyObject*){x_name});
Py_INCREF({output_name});
Py_INCREF({output_name});
"""
.
format
(
**
locals
())
"""
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
return
(
1
,)
return
(
1
,)
...
...
pytensor/typed_list/type.py
浏览文件 @
55b2f4fa
...
@@ -110,26 +110,27 @@ class TypedListType(CType):
...
@@ -110,26 +110,27 @@ class TypedListType(CType):
def
c_extract
(
self
,
name
,
sub
,
check_input
=
True
,
**
kwargs
):
def
c_extract
(
self
,
name
,
sub
,
check_input
=
True
,
**
kwargs
):
if
check_input
:
if
check_input
:
pre
=
"""
fail
=
sub
[
"fail"
]
pre
=
f
"""
if (!PyList_Check(py_{name})) {{
if (!PyList_Check(py_{name})) {{
PyErr_SetString(PyExc_TypeError, "expected a list");
PyErr_SetString(PyExc_TypeError, "expected a list");
{fail}
{fail}
}}"""
.
format
(
**
dict
(
name
=
name
,
fail
=
sub
[
"fail"
]))
}}"""
else
:
else
:
pre
=
""
pre
=
""
return
(
return
(
pre
pre
+
"""
+
f
"""
{name} = (PyListObject*) (py_{name});
{name} = (PyListObject*) (py_{name});
"""
.
format
(
**
dict
(
name
=
name
,
fail
=
sub
[
"fail"
]))
"""
)
)
def
c_sync
(
self
,
name
,
sub
):
def
c_sync
(
self
,
name
,
sub
):
return
"""
return
f
"""
Py_XDECREF(py_{name});
Py_XDECREF(py_{name});
py_{name} = (PyObject*)({name});
py_{name} = (PyObject*)({name});
Py_INCREF(py_{name});
Py_INCREF(py_{name});
"""
.
format
(
**
dict
(
name
=
name
))
"""
def
c_cleanup
(
self
,
name
,
sub
):
def
c_cleanup
(
self
,
name
,
sub
):
return
""
return
""
...
...
tests/compile/test_debugmode.py
浏览文件 @
55b2f4fa
...
@@ -64,7 +64,8 @@ class BROKEN_ON_PURPOSE_Add(COp):
...
@@ -64,7 +64,8 @@ class BROKEN_ON_PURPOSE_Add(COp):
def
c_code
(
self
,
node
,
name
,
inp
,
out
,
sub
):
def
c_code
(
self
,
node
,
name
,
inp
,
out
,
sub
):
a
,
b
=
inp
a
,
b
=
inp
(
z
,)
=
out
(
z
,)
=
out
return
"""
fail
=
sub
[
"fail"
]
return
f
"""
if (PyArray_NDIM({a}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a) != 1"); {fail};}}
if (PyArray_NDIM({a}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a) != 1"); {fail};}}
if (PyArray_NDIM({b}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 1"); {fail};}}
if (PyArray_NDIM({b}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 1"); {fail};}}
...
@@ -96,7 +97,7 @@ class BROKEN_ON_PURPOSE_Add(COp):
...
@@ -96,7 +97,7 @@ class BROKEN_ON_PURPOSE_Add(COp):
+ ((double*)PyArray_GETPTR1({b}, m))[0] ;
+ ((double*)PyArray_GETPTR1({b}, m))[0] ;
}}
}}
}}
}}
"""
.
format
(
**
dict
(
locals
(),
**
sub
))
"""
# inconsistent is a invalid op, whose perform and c_code do not match
# inconsistent is a invalid op, whose perform and c_code do not match
...
@@ -632,7 +633,8 @@ class BrokenCImplementationAdd(COp):
...
@@ -632,7 +633,8 @@ class BrokenCImplementationAdd(COp):
a
,
b
=
inp
a
,
b
=
inp
(
z
,)
=
out
(
z
,)
=
out
debug
=
0
debug
=
0
return
"""
fail
=
sub
[
"fail"
]
return
f
"""
//printf("executing c_code
\\
n");
//printf("executing c_code
\\
n");
if (PyArray_NDIM({a}) != 2) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a) != 2"); {fail};}}
if (PyArray_NDIM({a}) != 2) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a) != 2"); {fail};}}
if (PyArray_NDIM({b}) != 2) {{PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 2"); {fail};}}
if (PyArray_NDIM({b}) != 2) {{PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 2"); {fail};}}
...
@@ -690,7 +692,7 @@ class BrokenCImplementationAdd(COp):
...
@@ -690,7 +692,7 @@ class BrokenCImplementationAdd(COp):
}}
}}
}}
}}
}}
}}
"""
.
format
(
**
dict
(
locals
(),
**
sub
))
"""
class
VecAsRowAndCol
(
Op
):
class
VecAsRowAndCol
(
Op
):
...
...
tests/link/c/test_basic.py
浏览文件 @
55b2f4fa
...
@@ -39,7 +39,8 @@ class TDouble(CType):
...
@@ -39,7 +39,8 @@ class TDouble(CType):
return
str
(
data
)
return
str
(
data
)
def
c_extract
(
self
,
name
,
sub
,
check_input
=
True
,
**
kwargs
):
def
c_extract
(
self
,
name
,
sub
,
check_input
=
True
,
**
kwargs
):
return
"""
fail
=
sub
[
"fail"
]
return
f
"""
if (!PyFloat_Check(py_{name})) {{
if (!PyFloat_Check(py_{name})) {{
PyErr_SetString(PyExc_TypeError, "not a double!");
PyErr_SetString(PyExc_TypeError, "not a double!");
{fail}
{fail}
...
@@ -47,7 +48,7 @@ class TDouble(CType):
...
@@ -47,7 +48,7 @@ class TDouble(CType):
{name} = PyFloat_AsDouble(py_{name});
{name} = PyFloat_AsDouble(py_{name});
{name}_bad_thing = NULL;
{name}_bad_thing = NULL;
//printf("Extracting {name}
\\
n");
//printf("Extracting {name}
\\
n");
"""
.
format
(
**
dict
(
locals
(),
**
sub
))
"""
def
c_sync
(
self
,
name
,
sub
):
def
c_sync
(
self
,
name
,
sub
):
return
f
"""
return
f
"""
...
...
tests/link/c/test_cmodule.py
浏览文件 @
55b2f4fa
...
@@ -189,12 +189,11 @@ def mock_system(request):
...
@@ -189,12 +189,11 @@ def mock_system(request):
@pytest.fixture
()
@pytest.fixture
()
def
cxx_search_dirs
(
blas_libs
,
mock_system
):
def
cxx_search_dirs
(
blas_libs
,
mock_system
):
libext
=
{
"Linux"
:
"so"
,
"Windows"
:
"dll"
,
"Darwin"
:
"dylib"
}
libext
=
{
"Linux"
:
"so"
,
"Windows"
:
"dll"
,
"Darwin"
:
"dylib"
}
libtemplate
=
f
"{{lib}}.{libext[mock_system]}"
libraries
=
[]
libraries
=
[]
with
tempfile
.
TemporaryDirectory
()
as
d
:
with
tempfile
.
TemporaryDirectory
()
as
d
:
flags
=
None
flags
=
None
for
lib
in
blas_libs
:
for
lib
in
blas_libs
:
lib_path
=
Path
(
d
)
/
libtemplate
.
format
(
lib
=
lib
)
lib_path
=
Path
(
d
)
/
f
"{lib}.{libext[mock_system]}"
lib_path
.
write_bytes
(
b
"1"
)
lib_path
.
write_bytes
(
b
"1"
)
libraries
.
append
(
lib_path
)
libraries
.
append
(
lib_path
)
if
flags
is
None
:
if
flags
is
None
:
...
@@ -262,14 +261,13 @@ def test_default_blas_ldflags_no_cxx():
...
@@ -262,14 +261,13 @@ def test_default_blas_ldflags_no_cxx():
@pytest.fixture
()
@pytest.fixture
()
def
windows_conda_libs
(
blas_libs
):
def
windows_conda_libs
(
blas_libs
):
libtemplate
=
"{lib}.dll"
libraries
=
[]
libraries
=
[]
with
tempfile
.
TemporaryDirectory
()
as
d
:
with
tempfile
.
TemporaryDirectory
()
as
d
:
subdir
=
Path
(
d
)
/
"Library"
/
"bin"
subdir
=
Path
(
d
)
/
"Library"
/
"bin"
subdir
.
mkdir
(
exist_ok
=
True
,
parents
=
True
)
subdir
.
mkdir
(
exist_ok
=
True
,
parents
=
True
)
flags
=
f
'-L"{subdir}"'
flags
=
f
'-L"{subdir}"'
for
lib
in
blas_libs
:
for
lib
in
blas_libs
:
lib_path
=
subdir
/
libtemplate
.
format
(
lib
=
lib
)
lib_path
=
subdir
/
f
"{lib}.dll"
lib_path
.
write_bytes
(
b
"1"
)
lib_path
.
write_bytes
(
b
"1"
)
libraries
.
append
(
lib_path
)
libraries
.
append
(
lib_path
)
flags
+=
f
" -l{lib}"
flags
+=
f
" -l{lib}"
...
...
tests/link/c/test_op.py
浏览文件 @
55b2f4fa
...
@@ -79,10 +79,10 @@ class StructOp(COp):
...
@@ -79,10 +79,10 @@ class StructOp(COp):
return
f
"counter{name} = 0;"
return
f
"counter{name} = 0;"
def
c_code
(
self
,
node
,
name
,
input_names
,
outputs_names
,
sub
):
def
c_code
(
self
,
node
,
name
,
input_names
,
outputs_names
,
sub
):
return
"""
return
f
"""
{out} = counter{name};
{out
puts_names[0]
} = counter{name};
counter{name}++;
counter{name}++;
"""
.
format
(
**
dict
(
out
=
outputs_names
[
0
],
name
=
name
))
"""
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
return
(
1
,)
return
(
1
,)
...
...
tests/link/c/test_params_type.py
浏览文件 @
55b2f4fa
...
@@ -73,30 +73,24 @@ class QuadraticOpFunc(COp):
...
@@ -73,30 +73,24 @@ class QuadraticOpFunc(COp):
"""
"""
def
c_code
(
self
,
node
,
name
,
inputs
,
outputs
,
sub
):
def
c_code
(
self
,
node
,
name
,
inputs
,
outputs
,
sub
):
return
"""
X
=
inputs
[
0
]
{float_type} a = ({float_type}) (*(npy_float64*) PyArray_GETPTR1({coeff}->a, 0)); // 0-D TensorType.
Y
=
outputs
[
0
]
{float_type} b = {coeff}->b; // ScalarType.
float_type
=
node
.
inputs
[
0
]
.
type
.
c_element_type
()
{float_type} c = ({float_type}) PyFloat_AsDouble({coeff}->c); // Generic.
return
f
"""
{float_type} a = ({float_type}) (*(npy_float64*) PyArray_GETPTR1({sub['params']}->a, 0)); // 0-D TensorType.
{float_type} b = {sub['params']}->b; // ScalarType.
{float_type} c = ({float_type}) PyFloat_AsDouble({sub['params']}->c); // Generic.
Py_XDECREF({Y});
Py_XDECREF({Y});
{Y} = (PyArrayObject*)PyArray_EMPTY(PyArray_NDIM({X}), PyArray_DIMS({X}), PyArray_TYPE({X}), PyArray_IS_F_CONTIGUOUS({X}));
{Y} = (PyArrayObject*)PyArray_EMPTY(PyArray_NDIM({X}), PyArray_DIMS({X}), PyArray_TYPE({X}), PyArray_IS_F_CONTIGUOUS({X}));
if (PyArray_CopyInto({Y}, {X}) != 0) {{
if (PyArray_CopyInto({Y}, {X}) != 0) {{
PyErr_SetString(PyExc_RuntimeError, "Unable to copy input into output.");
PyErr_SetString(PyExc_RuntimeError, "Unable to copy input into output.");
{
fail
}
{
sub['fail']
}
}};
}};
if (quadratic_{name}({Y}, a, b, c) != 0) {{
if (quadratic_{name}({Y}, a, b, c) != 0) {{
PyErr_SetString(PyExc_RuntimeError, "Unable to compute quadratic function.");
PyErr_SetString(PyExc_RuntimeError, "Unable to compute quadratic function.");
{
fail
}
{
sub['fail']
}
}}
}}
"""
.
format
(
"""
**
dict
(
name
=
name
,
coeff
=
sub
[
"params"
],
fail
=
sub
[
"fail"
],
X
=
inputs
[
0
],
Y
=
outputs
[
0
],
float_type
=
node
.
inputs
[
0
]
.
type
.
c_element_type
(),
)
)
# Same op as above, but implemented as a ExternalCOp (with C code in an
# Same op as above, but implemented as a ExternalCOp (with C code in an
...
...
tests/link/c/test_type.py
浏览文件 @
55b2f4fa
...
@@ -25,11 +25,12 @@ Py_XDECREF((PyObject *)p);
...
@@ -25,11 +25,12 @@ Py_XDECREF((PyObject *)p);
"""
"""
def
c_code
(
self
,
node
,
name
,
inps
,
outs
,
sub
):
def
c_code
(
self
,
node
,
name
,
inps
,
outs
,
sub
):
return
"""
return
f
"""
Py_XDECREF({out});
Py_XDECREF({outs[0]});
{out} = (void *){inp};
{outs[0]} = (void *){inps[0]};
Py_INCREF({inp});
Py_INCREF({inps[0]});
"""
.
format
(
**
dict
(
out
=
outs
[
0
],
inp
=
inps
[
0
]))
"""
# FIXME: should it not be outs[0]?
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
return
(
0
,)
return
(
0
,)
...
@@ -52,11 +53,11 @@ Py_XDECREF((PyObject *)p);
...
@@ -52,11 +53,11 @@ Py_XDECREF((PyObject *)p);
"""
"""
def
c_code
(
self
,
node
,
name
,
inps
,
outs
,
sub
):
def
c_code
(
self
,
node
,
name
,
inps
,
outs
,
sub
):
return
"""
return
f
"""
Py_XDECREF({out});
Py_XDECREF({out
s[0]
});
{out
} = (PyArrayObject *){inp
};
{out
s[0]} = (PyArrayObject *){inps[0]
};
Py_INCREF({out});
Py_INCREF({out
s[0]
});
"""
.
format
(
**
dict
(
out
=
outs
[
0
],
inp
=
inps
[
0
]))
"""
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
return
(
0
,)
return
(
0
,)
...
@@ -136,7 +137,12 @@ class MyOpEnumList(COp):
...
@@ -136,7 +137,12 @@ class MyOpEnumList(COp):
return
(
1
,)
return
(
1
,)
def
c_code
(
self
,
node
,
name
,
inputs
,
outputs
,
sub
):
def
c_code
(
self
,
node
,
name
,
inputs
,
outputs
,
sub
):
return
"""
op
=
sub
[
"params"
]
o
=
outputs
[
0
]
a
=
inputs
[
0
]
b
=
inputs
[
1
]
fail
=
sub
[
"fail"
]
return
f
"""
switch({op}) {{
switch({op}) {{
case ADD:
case ADD:
{o} = {a} + {b};
{o} = {a} + {b};
...
@@ -154,15 +160,7 @@ class MyOpEnumList(COp):
...
@@ -154,15 +160,7 @@ class MyOpEnumList(COp):
{{{fail}}}
{{{fail}}}
break;
break;
}}
}}
"""
.
format
(
"""
**
dict
(
op
=
sub
[
"params"
],
o
=
outputs
[
0
],
a
=
inputs
[
0
],
b
=
inputs
[
1
],
fail
=
sub
[
"fail"
],
)
)
class
MyOpCEnumType
(
COp
):
class
MyOpCEnumType
(
COp
):
...
@@ -201,15 +199,10 @@ class MyOpCEnumType(COp):
...
@@ -201,15 +199,10 @@ class MyOpCEnumType(COp):
return
(
3
,)
return
(
3
,)
def
c_code
(
self
,
node
,
name
,
inputs
,
outputs
,
sub
):
def
c_code
(
self
,
node
,
name
,
inputs
,
outputs
,
sub
):
return
"""
# params in C code will already contains expected C constant value.
{o} = {val};
return
f
"""
"""
.
format
(
{outputs[0]} = {sub['params']};
**
dict
(
"""
o
=
outputs
[
0
],
# params in C code will already contains expected C constant value.
val
=
sub
[
"params"
],
)
)
class
TestEnumTypes
:
class
TestEnumTypes
:
...
...
tests/tensor/conv/c_conv3d_corr3d_ref.py
浏览文件 @
55b2f4fa
...
@@ -323,7 +323,9 @@ class BaseCorr3dMM(OpenMPOp, _NoPythonOp):
...
@@ -323,7 +323,9 @@ class BaseCorr3dMM(OpenMPOp, _NoPythonOp):
)
)
depth
=
"-1"
depth
=
"-1"
return
"""
fail
=
sub
[
"fail"
]
params
=
sub
[
"params"
]
return
f
"""
// Mandatory args
// Mandatory args
int direction = {params}->direction; // forward, bprop weights, bprop inputs
int direction = {params}->direction; // forward, bprop weights, bprop inputs
...
@@ -568,18 +570,7 @@ class BaseCorr3dMM(OpenMPOp, _NoPythonOp):
...
@@ -568,18 +570,7 @@ class BaseCorr3dMM(OpenMPOp, _NoPythonOp):
}}
}}
assert (out2 == *out);
assert (out2 == *out);
"""
.
format
(
"""
**
dict
(
bottom
=
bottom
,
weights
=
weights
,
top
=
top
,
height
=
height
,
width
=
width
,
depth
=
depth
,
fail
=
sub
[
"fail"
],
params
=
sub
[
"params"
],
)
)
class
Corr3dMM
(
BaseCorr3dMM
):
class
Corr3dMM
(
BaseCorr3dMM
):
...
...
tests/tensor/conv/c_conv_corr_ref.py
浏览文件 @
55b2f4fa
...
@@ -322,7 +322,9 @@ class BaseCorrMM(OpenMPOp, _NoPythonOp):
...
@@ -322,7 +322,9 @@ class BaseCorrMM(OpenMPOp, _NoPythonOp):
)
)
width
=
"-1"
width
=
"-1"
return
"""
fail
=
sub
[
"fail"
]
params
=
sub
[
"params"
]
return
f
"""
// Mandatory args
// Mandatory args
int direction = {params}->direction; // forward, bprop weights, bprop inputs
int direction = {params}->direction; // forward, bprop weights, bprop inputs
...
@@ -614,17 +616,7 @@ class BaseCorrMM(OpenMPOp, _NoPythonOp):
...
@@ -614,17 +616,7 @@ class BaseCorrMM(OpenMPOp, _NoPythonOp):
}}
}}
assert (out2 == *out);
assert (out2 == *out);
"""
.
format
(
"""
**
dict
(
bottom
=
bottom
,
weights
=
weights
,
top
=
top
,
height
=
height
,
width
=
width
,
fail
=
sub
[
"fail"
],
params
=
sub
[
"params"
],
)
)
class
CorrMM
(
BaseCorrMM
):
class
CorrMM
(
BaseCorrMM
):
...
...
tests/tensor/rewriting/test_elemwise.py
浏览文件 @
55b2f4fa
...
@@ -1391,9 +1391,9 @@ class TimesN(ps.basic.UnaryScalarOp):
...
@@ -1391,9 +1391,9 @@ class TimesN(ps.basic.UnaryScalarOp):
def
c_support_code_apply
(
self
,
node
,
nodename
):
def
c_support_code_apply
(
self
,
node
,
nodename
):
n
=
str
(
self
.
n
)
n
=
str
(
self
.
n
)
return
"""
return
f
"""
float {nodename}_timesn(float x) {{ return x * {n}; }}
float {nodename}_timesn(float x) {{ return x * {n}; }}
"""
.
format
(
**
locals
())
"""
def
c_code
(
self
,
node
,
name
,
inputs
,
outputs
,
sub
):
def
c_code
(
self
,
node
,
name
,
inputs
,
outputs
,
sub
):
(
x
,)
=
inputs
(
x
,)
=
inputs
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论