Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
55b2f4fa
提交
55b2f4fa
authored
7月 07, 2024
作者:
Virgile Andreani
提交者:
Virgile Andreani
7月 12, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Replace str.format with f-strings
上级
e2f9cb8e
显示空白字符变更
内嵌
并排
正在显示
39 个修改的文件
包含
355 行增加
和
456 行删除
+355
-456
builders.py
pytensor/compile/builders.py
+1
-1
types.py
pytensor/compile/function/types.py
+1
-5
basic.py
pytensor/graph/rewriting/basic.py
+5
-8
utils.py
pytensor/graph/utils.py
+3
-4
basic.py
pytensor/link/c/basic.py
+18
-22
cmodule.py
pytensor/link/c/cmodule.py
+2
-2
interface.py
pytensor/link/c/interface.py
+4
-8
op.py
pytensor/link/c/op.py
+4
-13
params_type.py
pytensor/link/c/params_type.py
+26
-42
type.py
pytensor/link/c/type.py
+33
-45
slinalg.py
pytensor/link/numba/dispatch/slinalg.py
+4
-5
utils.py
pytensor/link/utils.py
+3
-7
printing.py
pytensor/printing.py
+2
-4
basic.py
pytensor/scalar/basic.py
+15
-17
math.py
pytensor/scalar/math.py
+10
-10
utils.py
pytensor/scan/utils.py
+2
-1
basic.py
pytensor/sparse/basic.py
+6
-4
rewriting.py
pytensor/sparse/rewriting.py
+29
-20
basic.py
pytensor/tensor/basic.py
+18
-24
blas.py
pytensor/tensor/blas.py
+9
-16
blas_c.py
pytensor/tensor/blas_c.py
+2
-2
blas_headers.py
pytensor/tensor/blas_headers.py
+2
-2
elemwise.py
pytensor/tensor/elemwise.py
+19
-24
elemwise_cgen.py
pytensor/tensor/elemwise_cgen.py
+4
-2
extra_ops.py
pytensor/tensor/extra_ops.py
+8
-9
math.py
pytensor/tensor/math.py
+4
-3
subtensor.py
pytensor/tensor/subtensor.py
+22
-30
type.py
pytensor/tensor/type.py
+17
-15
basic.py
pytensor/typed_list/basic.py
+20
-21
type.py
pytensor/typed_list/type.py
+7
-6
test_debugmode.py
tests/compile/test_debugmode.py
+6
-4
test_basic.py
tests/link/c/test_basic.py
+3
-2
test_cmodule.py
tests/link/c/test_cmodule.py
+2
-4
test_op.py
tests/link/c/test_op.py
+3
-3
test_params_type.py
tests/link/c/test_params_type.py
+10
-16
test_type.py
tests/link/c/test_type.py
+21
-28
c_conv3d_corr3d_ref.py
tests/tensor/conv/c_conv3d_corr3d_ref.py
+4
-13
c_conv_corr_ref.py
tests/tensor/conv/c_conv_corr_ref.py
+4
-12
test_elemwise.py
tests/tensor/rewriting/test_elemwise.py
+2
-2
没有找到文件。
pytensor/compile/builders.py
浏览文件 @
55b2f4fa
...
...
@@ -434,7 +434,7 @@ class OpFromGraph(Op, HasInnerGraph):
def
__str__
(
self
):
name
=
self
.
__class__
.
__name__
if
self
.
name
is
None
else
self
.
name
is_inline
=
self
.
is_inline
return
"{name}{{inline={is_inline}}}"
.
format
(
**
locals
())
return
f
"{name}{{inline={is_inline}}}"
def
_combine_list_overrides
(
self
,
default_outs
,
custom_outs
,
callable_args
):
"""Combines default and custom overrides into a single list of outputs."""
...
...
pytensor/compile/function/types.py
浏览文件 @
55b2f4fa
...
...
@@ -1890,11 +1890,7 @@ def get_info_on_inputs(named_inputs, n_unnamed_inputs):
)
else
:
if
n_unnamed_inputs
==
0
:
msg
=
"The function has {} named input{} ({})."
.
format
(
n_named_inputs
,
get_plural
(
n_named_inputs
),
", "
.
join
(
named_inputs
),
)
msg
=
f
"The function has {n_named_inputs} named input{get_plural(n_named_inputs)} ({', '.join(named_inputs)})."
else
:
msg
=
(
f
"The function has {n_named_inputs} named input{get_plural(n_named_inputs)} ({', '.join(named_inputs)}), and {n_unnamed_inputs} unnamed "
...
...
pytensor/graph/rewriting/basic.py
浏览文件 @
55b2f4fa
...
...
@@ -1664,15 +1664,12 @@ class PatternNodeRewriter(NodeRewriter):
def
pattern_to_str
(
pattern
):
if
isinstance
(
pattern
,
list
|
tuple
):
return
"{}({})"
.
format
(
str
(
pattern
[
0
]),
", "
.
join
(
pattern_to_str
(
p
)
for
p
in
pattern
[
1
:]),
)
args
=
", "
.
join
(
pattern_to_str
(
p
)
for
p
in
pattern
[
1
:])
return
f
"{pattern[0]!s}({args})"
elif
isinstance
(
pattern
,
dict
):
return
"{} subject to {}"
.
format
(
pattern_to_str
(
pattern
[
"pattern"
]),
str
(
pattern
.
get
(
"constraint"
,
"no conditions"
)),
)
a
=
pattern_to_str
(
pattern
[
"pattern"
])
b
=
pattern
.
get
(
"constraint"
,
"no conditions"
)
return
f
"{a} subject to {b}"
else
:
return
str
(
pattern
)
...
...
pytensor/graph/utils.py
浏览文件 @
55b2f4fa
...
...
@@ -245,10 +245,9 @@ class MetaType(ABCMeta):
else
:
def
__str__
(
self
):
return
"{}{{{}}}"
.
format
(
self
.
__class__
.
__name__
,
", "
.
join
(
f
"{p}={getattr(self, p)!r}"
for
p
in
props
),
)
classname
=
self
.
__class__
.
__name__
args
=
", "
.
join
(
f
"{p}={getattr(self, p)!r}"
for
p
in
props
)
return
f
"{classname}{{{args}}}"
dct
[
"__str__"
]
=
__str__
...
...
pytensor/link/c/basic.py
浏览文件 @
55b2f4fa
...
...
@@ -219,7 +219,6 @@ def struct_gen(args, struct_builders, blocks, sub):
"""
struct_decl
=
""
struct_init_head
=
""
struct_init_tail
=
""
struct_cleanup
=
""
for
block
in
struct_builders
:
...
...
@@ -243,7 +242,6 @@ def struct_gen(args, struct_builders, blocks, sub):
# decrements the storage's refcount in the destructor
storage_decref
=
"
\n
"
.
join
(
f
"Py_XDECREF(this->{arg});"
for
arg
in
args
)
args_names
=
", "
.
join
(
args
)
args_decl
=
", "
.
join
(
f
"PyObject* {arg}"
for
arg
in
args
)
# The following code stores the exception data in __ERROR, which
...
...
@@ -251,7 +249,8 @@ def struct_gen(args, struct_builders, blocks, sub):
# that holds the type, the value and the traceback. After storing
# the error, we return the failure code so we know which code
# block failed.
do_return
=
"""
failure_var
=
sub
[
"failure_var"
]
do_return
=
f
"""
if ({failure_var}) {{
// When there is a failure, this code puts the exception
// in __ERROR.
...
...
@@ -274,15 +273,13 @@ def struct_gen(args, struct_builders, blocks, sub):
}}
// The failure code is returned to index what code block failed.
return {failure_var};
"""
.
format
(
**
sub
)
sub
=
dict
(
sub
)
sub
.
update
(
locals
())
"""
# TODO: add some error checking to make sure storage_<x> are
# 1-element lists and __ERROR is a 3-elements list.
struct_code
=
"""
name
=
sub
[
"name"
]
struct_code
=
f
"""
namespace {{
struct {name} {{
PyObject* __ERROR;
...
...
@@ -326,7 +323,7 @@ def struct_gen(args, struct_builders, blocks, sub):
}}
}};
}}
"""
.
format
(
**
sub
)
"""
return
struct_code
...
...
@@ -370,13 +367,10 @@ def get_c_init(fgraph, r, name, sub):
Wrapper around c_init that initializes py_name to Py_None.
"""
pre
=
(
""
"""
pre
=
f
"""
py_{name} = Py_None;
{{Py_XINCREF(py_{name});}}
"""
.
format
(
**
locals
())
)
"""
return
pre
+
r
.
type
.
c_init
(
name
,
sub
)
...
...
@@ -410,10 +404,10 @@ def get_c_extract(fgraph, r, name, sub):
else
:
c_extract
=
r
.
type
.
c_extract
(
name
,
sub
,
False
)
pre
=
"""
pre
=
f
"""
py_{name} = PyList_GET_ITEM(storage_{name}, 0);
{{Py_XINCREF(py_{name});}}
"""
.
format
(
**
locals
())
"""
return
pre
+
c_extract
...
...
@@ -439,10 +433,10 @@ def get_c_extract_out(fgraph, r, name, sub):
else
:
c_extract
=
r
.
type
.
c_extract_out
(
name
,
sub
,
check_input
,
check_broadcast
=
False
)
pre
=
"""
pre
=
f
"""
py_{name} = PyList_GET_ITEM(storage_{name}, 0);
{{Py_XINCREF(py_{name});}}
"""
.
format
(
**
locals
())
"""
return
pre
+
c_extract
...
...
@@ -451,9 +445,9 @@ def get_c_cleanup(fgraph, r, name, sub):
Wrapper around c_cleanup that decrefs py_name.
"""
post
=
"""
post
=
f
"""
{{Py_XDECREF(py_{name});}}
"""
.
format
(
**
locals
())
"""
return
r
.
type
.
c_cleanup
(
name
,
sub
)
+
post
...
...
@@ -462,7 +456,9 @@ def get_c_sync(fgraph, r, name, sub):
Wrapper around c_sync that syncs py_name with storage.
"""
return
"""
failure_var
=
sub
[
"failure_var"
]
sync
=
r
.
type
.
c_sync
(
name
,
sub
)
return
f
"""
if (!{failure_var}) {{
{sync}
PyObject* old = PyList_GET_ITEM(storage_{name}, 0);
...
...
@@ -470,7 +466,7 @@ def get_c_sync(fgraph, r, name, sub):
PyList_SET_ITEM(storage_{name}, 0, py_{name});
{{Py_XDECREF(old);}}
}}
"""
.
format
(
**
dict
(
sync
=
r
.
type
.
c_sync
(
name
,
sub
),
name
=
name
,
**
sub
))
"""
def
apply_policy
(
fgraph
,
policy
,
r
,
name
,
sub
):
...
...
pytensor/link/c/cmodule.py
浏览文件 @
55b2f4fa
...
...
@@ -1966,14 +1966,14 @@ class Compiler:
return
False
code
=
(
"""
f
"""
{preamble}
int main(int argc, char** argv)
{{
{body}
return 0;
}}
"""
.
format
(
**
locals
())
"""
)
.
encode
()
return
cls
.
_try_compile_tmp
(
code
,
...
...
pytensor/link/c/interface.py
浏览文件 @
55b2f4fa
...
...
@@ -558,7 +558,9 @@ class CLinkerType(CLinkerObject):
uninitialized.
"""
return
"""
c_init_code
=
self
.
c_init
(
name
,
sub
)
c_extract_code
=
self
.
c_extract
(
name
,
sub
,
check_input
)
return
f
"""
if (py_{name} == Py_None)
{{
{c_init_code}
...
...
@@ -567,13 +569,7 @@ class CLinkerType(CLinkerObject):
{{
{c_extract_code}
}}
"""
.
format
(
**
dict
(
name
=
name
,
c_init_code
=
self
.
c_init
(
name
,
sub
),
c_extract_code
=
self
.
c_extract
(
name
,
sub
,
check_input
),
)
)
"""
def
c_cleanup
(
self
,
name
:
str
,
sub
:
dict
[
str
,
str
])
->
str
:
"""Return C code to clean up after :meth:`CLinkerType.c_extract`.
...
...
pytensor/link/c/op.py
浏览文件 @
55b2f4fa
...
...
@@ -587,24 +587,15 @@ class ExternalCOp(COp):
params
=
f
", {sub['params']}"
# Generate the C code
return
"""
return
f
"""
{define_macros}
{{
if ({
func_name}({func_args
}{params}) != 0) {{
{
fail
}
if ({
self.func_name}({self.format_c_function_args(inp, out)
}{params}) != 0) {{
{
sub['fail']
}
}}
}}
{undef_macros}
"""
.
format
(
**
dict
(
func_name
=
self
.
func_name
,
fail
=
sub
[
"fail"
],
params
=
params
,
func_args
=
self
.
format_c_function_args
(
inp
,
out
),
define_macros
=
define_macros
,
undef_macros
=
undef_macros
,
)
)
"""
else
:
if
"code"
in
self
.
code_sections
:
op_code
=
self
.
code_sections
[
"code"
]
...
...
pytensor/link/c/params_type.py
浏览文件 @
55b2f4fa
...
...
@@ -262,9 +262,10 @@ class Params(dict):
self
.
__dict__
.
update
(
__params_type__
=
params_type
,
__signatures__
=
None
)
def
__repr__
(
self
):
return
"Params({})"
.
format
(
", "
.
join
((
f
"{k}:{type(self[k]).__name__}:{self[k]}"
)
for
k
in
sorted
(
self
)
)
args
=
", "
.
join
(
(
f
"{k}:{type(self[k]).__name__}:{self[k]}"
)
for
k
in
sorted
(
self
)
)
return
f
"Params({args})"
def
__getattr__
(
self
,
key
):
if
key
not
in
self
:
...
...
@@ -422,9 +423,10 @@ class ParamsType(CType):
return
super
()
.
__getattr__
(
self
,
key
)
def
__repr__
(
self
):
return
"ParamsType<{}>"
.
format
(
", "
.
join
((
f
"{self.fields[i]}:{self.types[i]}"
)
for
i
in
range
(
self
.
length
)
)
args
=
", "
.
join
(
f
"{self.fields[i]}:{self.types[i]}"
for
i
in
range
(
self
.
length
)
)
return
f
"ParamsType<{args}>"
def
__eq__
(
self
,
other
):
return
(
...
...
@@ -730,11 +732,15 @@ class ParamsType(CType):
struct_init
=
"
\n
"
.
join
(
c_init_list
)
struct_cleanup
=
"
\n
"
.
join
(
c_cleanup_list
)
struct_extract
=
"
\n\n
"
.
join
(
c_extract_list
)
struct_extract_method
=
"""
args
=
"
\n
"
.
join
(
f
"case {i}: extract_{self.fields[i]}(object); break;"
for
i
in
range
(
self
.
length
)
)
struct_extract_method
=
f
"""
void extract(PyObject* object, int field_pos) {{
switch(field_pos) {{
// Extraction cases.
{}
{
args
}
// Default case.
default:
PyErr_Format(PyExc_TypeError, "ParamsType: no extraction defined for a field
%
d.", field_pos);
...
...
@@ -742,13 +748,8 @@ class ParamsType(CType):
break;
}}
}}
"""
.
format
(
"
\n
"
.
join
(
(
"case
%
d: extract_
%
s(object); break;"
%
(
i
,
self
.
fields
[
i
]))
for
i
in
range
(
self
.
length
)
)
)
final_struct_code
=
"""
"""
final_struct_code
=
f
"""
/** ParamsType {struct_name} **/
#ifndef {struct_name_defined}
#define {struct_name_defined}
...
...
@@ -790,17 +791,7 @@ class ParamsType(CType):
}};
#endif
/** End ParamsType {struct_name} **/
"""
.
format
(
**
dict
(
struct_name_defined
=
struct_name_defined
,
struct_name
=
struct_name
,
struct_declare
=
struct_declare
,
struct_init
=
struct_init
,
struct_cleanup
=
struct_cleanup
,
struct_extract
=
struct_extract
,
struct_extract_method
=
struct_extract_method
,
)
)
"""
return
[
*
sorted
(
c_support_code_set
),
final_struct_code
]
...
...
@@ -813,9 +804,9 @@ class ParamsType(CType):
# pointers.
def
c_declare
(
self
,
name
,
sub
,
check_input
=
True
):
return
"""
{s
truct_
name}* {name};
"""
.
format
(
**
dict
(
struct_name
=
self
.
name
,
name
=
name
))
return
f
"""
{s
elf.
name}* {name};
"""
def
c_init
(
self
,
name
,
sub
):
# NB: It seems c_init() is not called for an op param.
...
...
@@ -831,40 +822,33 @@ class ParamsType(CType):
"""
def
c_extract
(
self
,
name
,
sub
,
check_input
=
True
,
**
kwargs
):
return
"""
fields_list
=
", "
.
join
(
f
'"{x}"'
for
x
in
self
.
fields
)
return
f
"""
/* Seems c_init() is not called for a op param. So I call `new` here. */
{name} = new {s
truct_
name};
{name} = new {s
elf.
name};
{{ // This need a separate namespace for Clinker
const char* fields[] = {{{fields_list}}};
if (py_{name} == Py_None) {{
PyErr_SetString(PyExc_ValueError, "ParamsType: expected an object, not None.");
{
fail
}
{
sub['fail']
}
}}
for (int i = 0; i < {length}; ++i) {{
for (int i = 0; i < {
self.
length}; ++i) {{
PyObject* o = PyDict_GetItemString(py_{name}, fields[i]);
if (o == NULL) {{
PyErr_Format(PyExc_TypeError, "ParamsType: missing expected attribute
\\
"
%
s
\\
" in object.", fields[i]);
{
fail
}
{
sub['fail']
}
}}
{name}->extract(o, i);
if ({name}->errorOccurred()) {{
/* The extract code from attribute type should have already raised a Python exception,
* so we just print the attribute name in stderr. */
fprintf(stderr, "
\\
nParamsType: error when extracting value for attribute
\\
"
%
s
\\
".
\\
n", fields[i]);
{
fail
}
{
sub['fail']
}
}}
}}
}}
"""
.
format
(
**
dict
(
name
=
name
,
struct_name
=
self
.
name
,
length
=
self
.
length
,
fail
=
sub
[
"fail"
],
fields_list
=
'"{}"'
.
format
(
'", "'
.
join
(
self
.
fields
)),
)
)
"""
def
c_sync
(
self
,
name
,
sub
):
# FIXME: Looks like we need to decrement a reference count our two.
...
...
pytensor/link/c/type.py
浏览文件 @
55b2f4fa
...
...
@@ -98,12 +98,12 @@ class Generic(CType, Singleton):
"""
def
c_sync
(
self
,
name
,
sub
):
return
"""
return
f
"""
assert(py_{name}->ob_refcnt > 1);
Py_DECREF(py_{name});
py_{name} = {name} ? {name} : Py_None;
Py_INCREF(py_{name});
"""
.
format
(
**
locals
())
"""
def
c_code_cache_version
(
self
):
return
(
1
,)
...
...
@@ -190,18 +190,19 @@ class CDataType(CType[D]):
return
data
def
c_declare
(
self
,
name
,
sub
,
check_input
=
True
):
return
"""
{ctype} {name};
"""
.
format
(
**
dict
(
ctype
=
self
.
ctype
,
name
=
name
))
return
f
"""
{
self.
ctype} {name};
"""
def
c_init
(
self
,
name
,
sub
):
return
f
"{name} = NULL;"
def
c_extract
(
self
,
name
,
sub
,
check_input
=
True
,
**
kwargs
):
return
"""
{name} = ({ctype})PyCapsule_GetPointer(py_{name}, NULL);
fail
=
sub
[
"fail"
]
return
f
"""
{name} = ({self.ctype})PyCapsule_GetPointer(py_{name}, NULL);
if ({name} == NULL) {fail}
"""
.
format
(
**
dict
(
name
=
name
,
ctype
=
self
.
ctype
,
fail
=
sub
[
"fail"
]))
"""
def
c_sync
(
self
,
name
,
sub
):
freefunc
=
self
.
freefunc
...
...
@@ -478,11 +479,8 @@ class EnumType(CType, dict):
names_to_aliases
=
{
constant_name
:
""
for
constant_name
in
self
}
for
alias
in
self
.
aliases
:
names_to_aliases
[
self
.
aliases
[
alias
]]
=
f
"({alias})"
return
"{}<{}>({})"
.
format
(
type
(
self
)
.
__name__
,
self
.
ctype
,
", "
.
join
(
f
"{k}{names_to_aliases[k]}:{self[k]}"
for
k
in
sorted
(
self
)),
)
args
=
", "
.
join
(
f
"{k}{names_to_aliases[k]}:{self[k]}"
for
k
in
sorted
(
self
))
return
f
"{type(self).__name__}<{self.ctype}>({args})"
def
__getattr__
(
self
,
key
):
if
key
in
self
:
...
...
@@ -575,33 +573,27 @@ class EnumType(CType, dict):
This C function may be useful to retrieve some runtime information.
It is available in C code when pytensor flag ``config.cmodule__debug`` is set to ``True``.
"""
return
"""
cases
=
""
.
join
(
f
"""
case {name}: sprintf(out, "{name}"); break;
"""
for
name
in
self
)
return
f
"""
#ifdef DEBUG
int pytensor_enum_to_string_{
cname}({
ctype} in, char* out) {{
int pytensor_enum_to_string_{
self.cname}({self.
ctype} in, char* out) {{
int ret = 0;
switch(in) {{
{cases}
default:
PyErr_SetString(PyExc_ValueError, "{
classname
}: unknown enum value.");
PyErr_SetString(PyExc_ValueError, "{
type(self).__name__
}: unknown enum value.");
ret = -1;
break;
}}
return ret;
}}
#endif
"""
.
format
(
**
dict
(
cname
=
self
.
cname
,
ctype
=
self
.
ctype
,
classname
=
type
(
self
)
.
__name__
,
cases
=
""
.
join
(
"""
case {name}: sprintf(out, "{name}"); break;
"""
.
format
(
**
dict
(
name
=
name
))
for
name
in
self
),
)
)
def
c_support_code
(
self
,
**
kwargs
):
return
(
...
...
@@ -625,16 +617,17 @@ class EnumType(CType, dict):
return
""
def
c_extract
(
self
,
name
,
sub
,
check_input
=
True
,
**
kwargs
):
return
"""
fail
=
sub
[
"fail"
]
return
f
"""
if (PyInt_Check(py_{name})) {{
{name} = ({ctype})PyInt_AsLong(py_{name});
{name} = ({
self.
ctype})PyInt_AsLong(py_{name});
}} else {{
{name} = ({ctype})PyFloat_AsDouble(py_{name});
{name} = ({
self.
ctype})PyFloat_AsDouble(py_{name});
}}
if (PyErr_Occurred()) {{
{fail}
}}
"""
.
format
(
**
dict
(
ctype
=
self
.
ctype
,
name
=
name
,
fail
=
sub
[
"fail"
]))
"""
def
c_code_cache_version
(
self
):
return
(
2
,
self
.
ctype
,
self
.
cname
,
tuple
(
self
.
items
()))
...
...
@@ -754,7 +747,14 @@ class CEnumType(EnumList):
swapped_dict
=
{
v
:
k
for
(
k
,
v
)
in
self
.
items
()}
# swapped_dict's keys are integers.
return
"""
fail
=
sub
[
"fail"
]
cases
=
""
.
join
(
f
"""
case {i}: {name} = {swapped_dict[i]}; break;
"""
for
i
in
sorted
(
swapped_dict
)
)
return
f
"""
switch(PyInt_AsLong(py_{name})) {{
{cases}
default:
...
...
@@ -762,19 +762,7 @@ class CEnumType(EnumList):
{{{fail}}}
break;
}}
"""
.
format
(
**
dict
(
name
=
name
,
cases
=
""
.
join
(
"""
case
%(i)
d:
%(name)
s =
%(constant_cname)
s; break;
"""
%
dict
(
i
=
i
,
name
=
name
,
constant_cname
=
swapped_dict
[
i
])
for
i
in
sorted
(
swapped_dict
)
),
fail
=
sub
[
"fail"
],
)
)
def
c_code_cache_version
(
self
):
return
(
1
,
super
()
.
c_code_cache_version
())
pytensor/link/numba/dispatch/slinalg.py
浏览文件 @
55b2f4fa
...
...
@@ -136,20 +136,19 @@ def _check_scipy_linalg_matrix(a, func_name):
Adapted from https://github.com/numba/numba/blob/bd7ebcfd4b850208b627a3f75d4706000be36275/numba/np/linalg.py#L831
"""
prefix
=
"scipy.linalg"
interp
=
(
prefix
,
func_name
)
# Unpack optional type
if
isinstance
(
a
,
types
.
Optional
):
a
=
a
.
type
if
not
isinstance
(
a
,
types
.
Array
):
msg
=
"{}.{}() only supported for array types"
.
format
(
*
interp
)
msg
=
f
"{prefix}.{func_name}() only supported for array types"
raise
numba
.
TypingError
(
msg
,
highlighting
=
False
)
if
a
.
ndim
not
in
[
1
,
2
]:
msg
=
"{}.{}() only supported on 1d or 2d arrays, found {}."
.
format
(
*
interp
,
a
.
ndim
msg
=
(
f
"{prefix}.{func_name}() only supported on 1d or 2d arrays, found {a.ndim}."
)
raise
numba
.
TypingError
(
msg
,
highlighting
=
False
)
if
not
isinstance
(
a
.
dtype
,
types
.
Float
|
types
.
Complex
):
msg
=
"{
}.{}() only supported on "
"float and complex arrays."
.
format
(
*
interp
)
msg
=
"{
prefix}.{func_name}() only supported on float and complex arrays."
raise
numba
.
TypingError
(
msg
,
highlighting
=
False
)
...
...
pytensor/link/utils.py
浏览文件 @
55b2f4fa
...
...
@@ -355,14 +355,10 @@ def raise_with_op(
+
f
"
\n
Inputs values: {scalar_values}"
)
if
verbosity
==
"high"
:
detailed_err_msg
+=
"
\n
Inputs type_num: {}"
.
format
(
str
(
[
getattr
(
getattr
(
i
[
0
],
"dtype"
,
""
),
"num"
,
""
)
for
i
in
thunk
.
inputs
inpts
=
[
getattr
(
getattr
(
i
[
0
],
"dtype"
,
""
),
"num"
,
""
)
for
i
in
thunk
.
inputs
]
)
)
detailed_err_msg
+=
f
"
\n
Inputs type_num: {inpts}"
detailed_err_msg
+=
f
"
\n
Outputs clients: {clients}
\n
"
else
:
...
...
pytensor/printing.py
浏览文件 @
55b2f4fa
...
...
@@ -1048,10 +1048,8 @@ class DefaultPrinter(Printer):
if
node
is
None
:
return
leaf_printer
.
process
(
output
,
pstate
)
with
set_precedence
(
pstate
):
r
=
"{}({})"
.
format
(
str
(
node
.
op
),
", "
.
join
(
pprinter
.
process
(
input
,
pstate
)
for
input
in
node
.
inputs
),
)
args
=
", "
.
join
(
pprinter
.
process
(
input
,
pstate
)
for
input
in
node
.
inputs
)
r
=
f
"{node.op}({args})"
pstate
.
memo
[
output
]
=
r
return
r
...
...
pytensor/scalar/basic.py
浏览文件 @
55b2f4fa
...
...
@@ -464,33 +464,32 @@ class ScalarType(CType, HasDataType, HasShape):
raise
NotImplementedError
(
"float16"
)
specs
=
self
.
dtype_specs
()
if
check_input
:
pre
=
"""
fail
=
sub
[
"fail"
]
dtype
=
specs
[
1
]
pyarr_type
=
f
"Py{specs[2]}ArrType_Type"
pre
=
f
"""
if (!PyObject_TypeCheck(py_{name}, &{pyarr_type}))
{{
PyErr_Format(PyExc_ValueError,
"Scalar check failed ({dtype})");
{fail}
}}
"""
.
format
(
**
dict
(
sub
,
name
=
name
,
dtype
=
specs
[
1
],
pyarr_type
=
f
"Py{specs[2]}ArrType_Type"
,
)
)
"""
else
:
pre
=
""
return
(
pre
+
"""
+
f
"""
PyArray_ScalarAsCtype(py_{name}, &{name});
"""
.
format
(
**
dict
(
sub
,
name
=
name
))
"""
)
def
c_sync
(
self
,
name
,
sub
):
specs
=
self
.
dtype_specs
()
return
"""
fail
=
sub
[
"fail"
]
dtype
=
specs
[
1
]
cls
=
specs
[
2
]
return
f
"""
Py_XDECREF(py_{name});
py_{name} = PyArrayScalar_New({cls});
if (!py_{name})
...
...
@@ -502,7 +501,7 @@ class ScalarType(CType, HasDataType, HasShape):
{fail}
}}
PyArrayScalar_ASSIGN(py_{name}, {cls}, {name});
"""
.
format
(
**
dict
(
sub
,
name
=
name
,
dtype
=
specs
[
1
],
cls
=
specs
[
2
]))
"""
def
c_cleanup
(
self
,
name
,
sub
):
return
""
...
...
@@ -1179,10 +1178,9 @@ class ScalarOp(COp):
not
in
(
"name"
,
"_op_use_c_code"
,
"bool"
,
"output_types_preference"
)
]
if
param
:
return
"{}{{{}}}"
.
format
(
self
.
__class__
.
__name__
,
", "
.
join
(
f
"{k}={v}"
for
k
,
v
in
param
),
)
classname
=
self
.
__class__
.
__name__
args
=
", "
.
join
(
f
"{k}={v}"
for
k
,
v
in
param
)
return
f
"{classname}{{{args}}}"
else
:
return
self
.
__class__
.
__name__
...
...
pytensor/scalar/math.py
浏览文件 @
55b2f4fa
...
...
@@ -615,8 +615,8 @@ class Chi2SF(BinaryScalarOp):
(
z
,)
=
out
if
node
.
inputs
[
0
]
.
type
in
float_types
:
dtype
=
"npy_"
+
node
.
outputs
[
0
]
.
dtype
return
"""{z} =
({dtype}) 1 - GammaP({k}/2., {x}/2.);"""
.
format
(
**
locals
())
return
f
"""{z} =
({dtype}) 1 - GammaP({k}/2., {x}/2.);"""
raise
NotImplementedError
(
"only floatingpoint is implemented"
)
def
__eq__
(
self
,
other
):
...
...
@@ -666,8 +666,8 @@ class GammaInc(BinaryScalarOp):
(
z
,)
=
out
if
node
.
inputs
[
0
]
.
type
in
float_types
:
dtype
=
"npy_"
+
node
.
outputs
[
0
]
.
dtype
return
"""{z} =
({dtype}) GammaP({k}, {x});"""
.
format
(
**
locals
())
return
f
"""{z} =
({dtype}) GammaP({k}, {x});"""
raise
NotImplementedError
(
"only floatingpoint is implemented"
)
def
__eq__
(
self
,
other
):
...
...
@@ -717,8 +717,8 @@ class GammaIncC(BinaryScalarOp):
(
z
,)
=
out
if
node
.
inputs
[
0
]
.
type
in
float_types
:
dtype
=
"npy_"
+
node
.
outputs
[
0
]
.
dtype
return
"""{z} =
({dtype}) GammaQ({k}, {x});"""
.
format
(
**
locals
())
return
f
"""{z} =
({dtype}) GammaQ({k}, {x});"""
raise
NotImplementedError
(
"only floatingpoint is implemented"
)
def
__eq__
(
self
,
other
):
...
...
@@ -1028,8 +1028,8 @@ class GammaU(BinaryScalarOp):
(
z
,)
=
out
if
node
.
inputs
[
0
]
.
type
in
float_types
:
dtype
=
"npy_"
+
node
.
outputs
[
0
]
.
dtype
return
"""{z} =
({dtype}) upperGamma({k}, {x});"""
.
format
(
**
locals
())
return
f
"""{z} =
({dtype}) upperGamma({k}, {x});"""
raise
NotImplementedError
(
"only floatingpoint is implemented"
)
def
__eq__
(
self
,
other
):
...
...
@@ -1064,8 +1064,8 @@ class GammaL(BinaryScalarOp):
(
z
,)
=
out
if
node
.
inputs
[
0
]
.
type
in
float_types
:
dtype
=
"npy_"
+
node
.
outputs
[
0
]
.
dtype
return
"""{z} =
({dtype}) lowerGamma({k}, {x});"""
.
format
(
**
locals
())
return
f
"""{z} =
({dtype}) lowerGamma({k}, {x});"""
raise
NotImplementedError
(
"only floatingpoint is implemented"
)
def
__eq__
(
self
,
other
):
...
...
pytensor/scan/utils.py
浏览文件 @
55b2f4fa
...
...
@@ -1072,7 +1072,8 @@ class ScanArgs:
for
p
in
self
.
field_names
if
p
.
startswith
(
"outer_out"
)
]
res
=
"ScanArgs(
\n
{})"
.
format
(
",
\n
"
.
join
(
inner_arg_strs
))
args
=
",
\n
"
.
join
(
inner_arg_strs
)
res
=
f
"ScanArgs(
\n
{args})"
return
res
def
__repr__
(
self
):
...
...
pytensor/sparse/basic.py
浏览文件 @
55b2f4fa
...
...
@@ -3617,7 +3617,8 @@ class StructuredDotGradCSC(COp):
if
node
.
inputs
[
3
]
.
type
.
dtype
in
(
"complex64"
,
"complex128"
):
raise
NotImplementedError
(
"Complex types are not supported for g_ab"
)
return
"""
fail
=
sub
[
"fail"
]
return
f
"""
if (PyArray_NDIM({_d}) != 2) {{PyErr_SetString(PyExc_NotImplementedError, "rank(d) != 2"); {fail};}}
if (PyArray_NDIM({_g}) != 2) {{PyErr_SetString(PyExc_NotImplementedError, "rank(g) != 2"); {fail};}}
if (PyArray_NDIM({_indices}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(indices) != 1"); {fail};}}
...
...
@@ -3689,7 +3690,7 @@ class StructuredDotGradCSC(COp):
}}
}}
"""
.
format
(
**
dict
(
locals
(),
**
sub
))
"""
def
infer_shape
(
self
,
fgraph
,
node
,
shapes
):
return
[
shapes
[
0
]]
...
...
@@ -3750,7 +3751,8 @@ class StructuredDotGradCSR(COp):
if
node
.
inputs
[
3
]
.
type
.
dtype
in
(
"complex64"
,
"complex128"
):
raise
NotImplementedError
(
"Complex types are not supported for g_ab"
)
return
"""
fail
=
sub
[
"fail"
]
return
f
"""
if (PyArray_NDIM({_d}) != 2) {{PyErr_SetString(PyExc_NotImplementedError, "rank(d) != 2"); {fail};}}
if (PyArray_NDIM({_g}) != 2) {{PyErr_SetString(PyExc_NotImplementedError, "rank(g) != 2"); {fail};}}
if (PyArray_NDIM({_indices}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(indices) != 1"); {fail};}}
...
...
@@ -3823,7 +3825,7 @@ class StructuredDotGradCSR(COp):
}}
}}
"""
.
format
(
**
dict
(
locals
(),
**
sub
))
"""
def
infer_shape
(
self
,
fgraph
,
node
,
shapes
):
return
[
shapes
[
0
]]
...
...
pytensor/sparse/rewriting.py
浏览文件 @
55b2f4fa
...
...
@@ -137,7 +137,7 @@ class AddSD_ccode(_NoPythonCOp):
inplace
=
int
(
self
.
inplace
)
format
=
{
"csc"
:
0
,
"csr"
:
1
}[
self
.
format
]
out_typenum
=
node
.
outputs
[
0
]
.
type
.
dtype_specs
()[
2
]
code
=
"""
code
=
f
"""
Py_XDECREF({z});
if (!{inplace}){{
if(PyArray_TYPE({y}) != {out_typenum}){{
...
...
@@ -179,7 +179,7 @@ class AddSD_ccode(_NoPythonCOp):
}}
}}
}}
"""
.
format
(
**
dict
(
locals
(),
**
sub
))
"""
return
code
def
infer_shape
(
self
,
fgraph
,
node
,
shapes
):
...
...
@@ -310,7 +310,8 @@ class StructuredDotCSC(COp):
typenum_a_val
=
node
.
inputs
[
0
]
.
type
.
dtype_specs
()[
2
]
# retrieve dtype number
typenum_b
=
node
.
inputs
[
4
]
.
type
.
dtype_specs
()[
2
]
# retrieve dtype number
rval
=
"""
fail
=
sub
[
"fail"
]
rval
=
f
"""
if (PyArray_NDIM({a_val}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a_val) != 1"); {fail};}}
if (PyArray_NDIM({a_ind}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a_ind) != 1"); {fail};}}
...
...
@@ -430,7 +431,7 @@ class StructuredDotCSC(COp):
}}
}}
}}
"""
.
format
(
**
dict
(
locals
(),
**
sub
))
"""
return
rval
...
...
@@ -517,7 +518,8 @@ class StructuredDotCSR(COp):
if
node
.
inputs
[
3
]
.
type
.
dtype
in
(
"complex64"
,
"complex128"
):
raise
NotImplementedError
(
"Complex types are not supported for b"
)
return
"""
fail
=
sub
[
"fail"
]
return
f
"""
if (PyArray_NDIM({a_val}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a_val) != 1"); {fail};}}
if (PyArray_NDIM({a_ind}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a_ind) != 1"); {fail};}}
if (PyArray_NDIM({a_ptr}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a_ptr) != 1"); {fail};}}
...
...
@@ -609,7 +611,7 @@ class StructuredDotCSR(COp):
}}
}}
"""
.
format
(
**
dict
(
locals
(),
**
sub
))
"""
def
c_code_cache_version
(
self
):
return
(
2
,)
...
...
@@ -756,7 +758,8 @@ class UsmmCscDense(_NoPythonCOp):
inplace
=
int
(
self
.
inplace
)
rval
=
"""
fail
=
sub
[
"fail"
]
rval
=
f
"""
if (PyArray_NDIM({x_val}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(x_val) != 1"); {fail};}}
if (PyArray_NDIM({x_ind}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(x_ind) != 1"); {fail};}}
...
...
@@ -888,7 +891,7 @@ class UsmmCscDense(_NoPythonCOp):
}}
}}
}}
"""
.
format
(
**
dict
(
locals
(),
**
sub
))
"""
return
rval
...
...
@@ -985,7 +988,8 @@ class CSMGradC(_NoPythonCOp):
if
node
.
inputs
[
3
]
.
type
.
dtype
in
(
"complex64"
,
"complex128"
):
raise
NotImplementedError
(
"Complex types are not supported for b_val"
)
return
"""
fail
=
sub
[
"fail"
]
return
f
"""
if (PyArray_NDIM({a_val}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a_val) != 1"); {fail};}}
if (PyArray_NDIM({a_ind}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a_ind) != 1"); {fail};}}
if (PyArray_NDIM({a_ptr}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a_ptr) != 1"); {fail};}}
...
...
@@ -1079,7 +1083,7 @@ class CSMGradC(_NoPythonCOp):
}}
}}
"""
.
format
(
**
dict
(
locals
(),
**
sub
))
"""
def
c_code_cache_version
(
self
):
return
(
3
,)
...
...
@@ -1165,7 +1169,8 @@ class MulSDCSC(_NoPythonCOp):
if
node
.
inputs
[
3
]
.
type
.
dtype
in
(
"complex64"
,
"complex128"
):
raise
NotImplementedError
(
"Complex types are not supported for b"
)
return
"""
fail
=
sub
[
"fail"
]
return
f
"""
if (PyArray_NDIM({_b}) != 2) {{
PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 2");
{fail};}}
...
...
@@ -1231,7 +1236,7 @@ class MulSDCSC(_NoPythonCOp):
}}
}}
"""
.
format
(
**
dict
(
locals
(),
**
sub
))
"""
def
__str__
(
self
):
return
self
.
__class__
.
__name__
...
...
@@ -1302,7 +1307,8 @@ class MulSDCSR(_NoPythonCOp):
if
node
.
inputs
[
3
]
.
type
.
dtype
in
(
"complex64"
,
"complex128"
):
raise
NotImplementedError
(
"Complex types are not supported for b"
)
return
"""
fail
=
sub
[
"fail"
]
return
f
"""
if (PyArray_NDIM({_b}) != 2) {{
PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 2");
{fail};}}
...
...
@@ -1368,7 +1374,7 @@ class MulSDCSR(_NoPythonCOp):
}}
}}
"""
.
format
(
**
dict
(
locals
(),
**
sub
))
"""
def
__str__
(
self
):
return
self
.
__class__
.
__name__
...
...
@@ -1491,7 +1497,8 @@ class MulSVCSR(_NoPythonCOp):
if
node
.
inputs
[
3
]
.
type
.
dtype
in
(
"complex64"
,
"complex128"
):
raise
NotImplementedError
(
"Complex types are not supported for b"
)
return
"""
fail
=
sub
[
"fail"
]
return
f
"""
if (PyArray_NDIM({_b}) != 1) {{
PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 1");
{fail};
...
...
@@ -1553,7 +1560,7 @@ class MulSVCSR(_NoPythonCOp):
}}
}}
"""
.
format
(
**
dict
(
locals
(),
**
sub
))
"""
def
__str__
(
self
):
return
self
.
__class__
.
__name__
...
...
@@ -1663,7 +1670,8 @@ class StructuredAddSVCSR(_NoPythonCOp):
if
node
.
inputs
[
3
]
.
type
.
dtype
in
(
"complex64"
,
"complex128"
):
raise
NotImplementedError
(
"Complex types are not supported for b"
)
return
"""
fail
=
sub
[
"fail"
]
return
f
"""
if (PyArray_NDIM({_b}) != 1) {{
PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 1");
{fail};
...
...
@@ -1732,7 +1740,7 @@ class StructuredAddSVCSR(_NoPythonCOp):
}}
}}
"""
.
format
(
**
dict
(
locals
(),
**
sub
))
"""
def
__str__
(
self
):
return
self
.
__class__
.
__name__
...
...
@@ -1904,7 +1912,8 @@ class SamplingDotCSR(_NoPythonCOp):
typenum_zi
=
TensorType
(
node
.
outputs
[
1
]
.
dtype
,
[])
.
dtype_specs
()[
2
]
typenum_zp
=
TensorType
(
node
.
outputs
[
2
]
.
dtype
,
[])
.
dtype_specs
()[
2
]
rval
=
"""
fail
=
sub
[
"fail"
]
rval
=
f
"""
if (PyArray_NDIM({x}) != 2) {{
PyErr_SetString(PyExc_NotImplementedError, "rank(x) != 2"); {fail};}}
if (PyArray_NDIM({y}) != 2) {{
...
...
@@ -2024,7 +2033,7 @@ PyErr_SetString(PyExc_NotImplementedError, "rank(y) != 2"); {fail};}}
}}
}}
}}
"""
.
format
(
**
dict
(
locals
(),
**
sub
))
"""
return
rval
...
...
pytensor/tensor/basic.py
浏览文件 @
55b2f4fa
...
...
@@ -597,12 +597,12 @@ class TensorFromScalar(COp):
(
z
,)
=
outputs
fail
=
sub
[
"fail"
]
return
"""
return
f
"""
{z} = (PyArrayObject*)PyArray_FromScalar(py_{x}, NULL);
if({z} == NULL){{
{fail};
}}
"""
.
format
(
**
locals
())
"""
def
c_code_cache_version
(
self
):
return
(
2
,)
...
...
@@ -646,10 +646,9 @@ class ScalarFromTensor(COp):
def
c_code
(
self
,
node
,
name
,
inputs
,
outputs
,
sub
):
(
x
,)
=
inputs
(
z
,)
=
outputs
fail
=
sub
[
"fail"
]
return
"""
return
f
"""
{z} = ((dtype_{x}*)(PyArray_DATA({x})))[0];
"""
.
format
(
**
locals
())
"""
def
c_code_cache_version
(
self
):
return
(
1
,)
...
...
@@ -1836,18 +1835,18 @@ class MakeVector(COp):
assert
self
.
dtype
==
node
.
inputs
[
0
]
.
dtype
out_num
=
f
"PyArray_TYPE({inp[0]})"
ret
=
"""
ret
=
f
"""
npy_intp dims[1];
dims[0] = {out_shape};
if(!{out} || PyArray_DIMS({out})[0] != {out_shape}){{
Py_XDECREF({out});
{out} = (PyArrayObject*)PyArray_EMPTY(1, dims, {out_num}, 0);
}}
"""
.
format
(
**
locals
())
"""
for
idx
,
i
in
enumerate
(
inp
):
ret
+=
"""
ret
+=
f
"""
*(({out_dtype} *)PyArray_GETPTR1({out}, {idx})) = *(({out_dtype} *) PyArray_DATA({i}));
"""
.
format
(
**
locals
())
"""
return
ret
def
infer_shape
(
self
,
fgraph
,
node
,
ishapes
):
...
...
@@ -2225,7 +2224,7 @@ class Split(COp):
splits_dtype
=
node
.
inputs
[
2
]
.
type
.
dtype_specs
()[
1
]
expected_splits_count
=
self
.
len_splits
return
"""
return
f
"""
int ndim = PyArray_NDIM({x});
int axis = (int)(*({axis_dtype}*)PyArray_GETPTR1({axis}, 0));
int splits_count = PyArray_DIM({splits}, 0);
...
...
@@ -2322,7 +2321,7 @@ class Split(COp):
}}
free(split_dims);
"""
.
format
(
**
locals
())
"""
class
Join
(
COp
):
...
...
@@ -2373,10 +2372,9 @@ class Join(COp):
if
self
.
view
==
-
1
:
return
self
.
__class__
.
__name__
else
:
return
"{}{{{}}}"
.
format
(
self
.
__class__
.
__name__
,
", "
.
join
(
f
"{p}={getattr(self, p)!r}"
for
p
in
self
.
__props__
),
)
classname
=
self
.
__class__
.
__name__
args
=
", "
.
join
(
f
"{p}={getattr(self, p)!r}"
for
p
in
self
.
__props__
)
return
f
"{classname}{{{args}}}"
def
__setstate__
(
self
,
d
):
self
.
__dict__
.
update
(
d
)
...
...
@@ -2538,7 +2536,7 @@ class Join(COp):
copy_inputs_to_list
=
"
\n
"
.
join
(
copy_to_list
)
n
=
len
(
tens
)
code
=
"""
code
=
f
"""
int axis = (({adtype} *)PyArray_DATA({axis}))[0];
PyObject* list = PyList_New({l});
{copy_inputs_to_list}
...
...
@@ -2570,7 +2568,7 @@ class Join(COp):
{fail}
}}
}}
"""
.
format
(
**
locals
())
"""
return
code
def
R_op
(
self
,
inputs
,
eval_points
):
...
...
@@ -4192,18 +4190,14 @@ class AllocEmpty(COp):
params
=
sub
[
"params"
]
str
=
f
"npy_intp dims[{nd}];
\n
"
for
idx
,
sh
in
enumerate
(
shps
):
str
+=
(
"dims[{idx}] ="
"((npy_intp)((dtype_{sh}*)"
" PyArray_DATA({sh}))[0]);
\n
"
.
format
(
**
locals
())
)
str
+=
f
"dims[{idx}] = ((npy_intp)((dtype_{sh}*) PyArray_DATA({sh}))[0]);
\n
"
# Validate that the output storage exists
str
+=
f
"if({out}==NULL
\n
"
for
idx
,
sh
in
enumerate
(
shps
):
str
+=
f
"||PyArray_DIMS({out})[{idx}]!=dims[{idx}]"
str
+=
"""){{
str
+=
f
"""){{
/* Reference received to invalid output variable.
Decrease received reference's ref count and allocate new
output variable */
...
...
@@ -4218,7 +4212,7 @@ class AllocEmpty(COp):
{fail};
}}
}}
"""
.
format
(
**
locals
())
"""
return
str
def
infer_shape
(
self
,
fgraph
,
node
,
input_shapes
):
...
...
pytensor/tensor/blas.py
浏览文件 @
55b2f4fa
...
...
@@ -1806,19 +1806,12 @@ class BatchedDot(COp):
strides
=
f
"PyArray_STRIDES({var})"
if
ndim
==
1
:
return
f
"{strides}[0] == type_size"
return
" && "
.
join
(
[
" && "
.
join
(
ands
=
" && "
.
join
(
f
"{strides}[{i}] > 0 && {strides}[{i}]
%
type_size == 0"
for
i
in
range
(
1
,
ndim
)
),
"({})"
.
format
(
" || "
.
join
(
f
"{strides}[{i}] == type_size"
for
i
in
range
(
1
,
ndim
)
)
),
]
)
ors
=
" || "
.
join
(
f
"{strides}[{i}] == type_size"
for
i
in
range
(
1
,
ndim
))
return
f
"{ands} && ({ors})"
x_ndim
,
y_ndim
,
z_ndim
=
(
node
.
inputs
[
0
]
.
ndim
,
...
...
@@ -1838,7 +1831,7 @@ class BatchedDot(COp):
)
z_shape
=
", "
.
join
(
z_dims
)
z_contiguous
=
contiguous
(
_z
,
z_ndim
)
allocate
=
"""
allocate
=
f
"""
if (NULL == {_z} || !({z_shape_correct}) || !({z_contiguous}))
{{
npy_intp dims[{z_ndim}] = {{{z_shape}}};
...
...
@@ -1851,14 +1844,14 @@ class BatchedDot(COp):
{fail}
}}
}}
"""
.
format
(
**
locals
())
"""
# code to reallocate inputs contiguously if necessary
contiguate
=
[]
for
var
,
ndim
in
[(
_x
,
x_ndim
),
(
_y
,
y_ndim
)]:
_contiguous
=
contiguous
(
var
,
ndim
)
contiguate
.
append
(
"""
f
"""
if (!({_contiguous})) {{
PyArrayObject * _copy = (PyArrayObject *) PyArray_Copy({var});
if (!_copy)
...
...
@@ -1866,11 +1859,11 @@ class BatchedDot(COp):
Py_XDECREF({var});
{var} = _copy;
}}
"""
.
format
(
**
locals
())
"""
)
contiguate
=
"
\n
"
.
join
(
contiguate
)
return
"""
return
f
"""
int type_num = PyArray_DESCR({_x})->type_num;
int type_size = PyArray_DESCR({_x})->elsize; // in bytes
...
...
@@ -1927,7 +1920,7 @@ class BatchedDot(COp):
}}
break;
}}
"""
.
format
(
**
locals
())
"""
def
c_code_cache_version
(
self
):
from
pytensor.tensor.blas_headers
import
blas_header_version
...
...
pytensor/tensor/blas_c.py
浏览文件 @
55b2f4fa
...
...
@@ -33,7 +33,7 @@ class BaseBLAS(COp):
def
ger_c_code
(
A
,
a
,
x
,
y
,
Z
,
fail
,
params
):
return
"""
return
f
"""
int elemsize ;
...
...
@@ -309,7 +309,7 @@ def ger_c_code(A, a, x, y, Z, fail, params):
}}
}}
"""
.
format
(
**
locals
())
"""
class
CGer
(
BaseBLAS
,
Ger
):
...
...
pytensor/tensor/blas_headers.py
浏览文件 @
55b2f4fa
...
...
@@ -1066,7 +1066,7 @@ def blas_header_version():
def
____gemm_code
(
check_ab
,
a_init
,
b_init
):
mod
=
"
%
"
return
"""
return
f
"""
const char * error_string = NULL;
int type_num = PyArray_DESCR(_x)->type_num;
...
...
@@ -1203,4 +1203,4 @@ def ____gemm_code(check_ab, a_init, b_init):
return -1;
/* v 1 */
"""
.
format
(
**
locals
())
"""
pytensor/tensor/elemwise.py
浏览文件 @
55b2f4fa
...
...
@@ -302,10 +302,7 @@ class DimShufflePrinter(Printer):
return
pstate
.
pprinter
.
process
(
r
)
if
list
(
new_order
)
==
list
(
reversed
(
range
(
r
.
type
.
ndim
))):
return
f
"{pstate.pprinter.process(r)}.T"
return
"DimShuffle{{{}}}({})"
.
format
(
", "
.
join
(
map
(
str
,
new_order
)),
pstate
.
pprinter
.
process
(
r
),
)
return
f
"DimShuffle{{{', '.join(str(o) for o in new_order)}}}({pstate.pprinter.process(r)})"
def
process
(
self
,
r
,
pstate
):
if
r
.
owner
is
None
:
...
...
@@ -929,13 +926,13 @@ class Elemwise(OpenMPOp):
# We make the output point to the corresponding input and
# decrease the reference of whatever the output contained
# prior to this
alloc
+=
"""
alloc
+=
f
"""
if ({oname}) {{
Py_XDECREF({oname});
}}
{oname} = {iname};
Py_XINCREF({oname});
"""
.
format
(
**
locals
())
"""
# We alias the scalar variables
defines
+=
f
"#define {oname}_i {iname}_i
\n
"
undefs
+=
f
"#undef {oname}_i
\n
"
...
...
@@ -958,13 +955,13 @@ class Elemwise(OpenMPOp):
[
f
"{s}_i"
for
s
in
onames
],
dict
(
sub
,
fail
=
fail
),
)
code
=
"""
code
=
f
"""
{{
{defines}
{task_code}
{undefs}
}}
"""
.
format
(
**
locals
())
"""
loop_orders
=
orders
+
[
list
(
range
(
nnested
))]
*
len
(
real_onames
)
dtypes
=
idtypes
+
list
(
real_odtypes
)
...
...
@@ -994,19 +991,17 @@ class Elemwise(OpenMPOp):
if
index
!=
"x"
:
preloops
.
setdefault
(
j
,
""
)
preloops
[
j
]
+=
(
"
%
(lv{i})s_iter = ({dtype}*)"
"(PyArray_DATA(
%
(lv{i})s));
\n
"
.
format
(
**
locals
())
f
"
%
(lv{i})s_iter = ({dtype}*)(PyArray_DATA(
%
(lv{i})s));
\n
"
)
%
sub
break
else
:
# all broadcastable
preloops
.
setdefault
(
0
,
""
)
preloops
[
0
]
+=
(
"
%
(lv{i})s_iter = ({dtype}*)"
"(PyArray_DATA(
%
(lv{i})s));
\n
"
.
format
(
**
locals
())
f
"
%
(lv{i})s_iter = ({dtype}*)(PyArray_DATA(
%
(lv{i})s));
\n
"
)
%
sub
init_array
=
preloops
.
get
(
0
,
" "
)
loop
=
"""
loop
=
f
"""
{{
{defines}
{init_array}
...
...
@@ -1014,7 +1009,7 @@ class Elemwise(OpenMPOp):
{task_code}
{undefs}
}}
"""
.
format
(
**
locals
())
"""
else
:
loop
=
cgen
.
make_loop
(
loop_orders
=
loop_orders
,
...
...
@@ -1074,25 +1069,25 @@ class Elemwise(OpenMPOp):
index
=
""
for
x
,
var
in
zip
(
inames
+
onames
,
inputs
+
node
.
outputs
):
if
not
all
(
s
==
1
for
s
in
var
.
type
.
shape
):
contig
+=
"""
contig
+=
f
"""
dtype_{x} * {x}_ptr = (dtype_{x}*) PyArray_DATA({x});
"""
.
format
(
**
locals
())
index
+=
"""
"""
index
+=
f
"""
dtype_{x}& {x}_i = {x}_ptr[i];
"""
.
format
(
**
locals
())
"""
else
:
contig
+=
"""
contig
+=
f
"""
dtype_{x}& {x}_i = ((dtype_{x}*) PyArray_DATA({x}))[0];
"""
.
format
(
**
locals
())
"""
if
self
.
openmp
:
contig
+=
f
"""#pragma omp parallel for if(n>={int(config.openmp_elemwise_minsize)})
"""
contig
+=
"""
contig
+=
f
"""
for(int i=0; i<n; i++){{
{index}
{task_code};
}}
"""
.
format
(
**
locals
())
"""
if
contig
is
not
None
:
z
=
list
(
zip
(
inames
+
onames
,
inputs
+
node
.
outputs
))
all_broadcastable
=
all
(
s
==
1
for
s
in
var
.
type
.
shape
)
...
...
@@ -1106,13 +1101,13 @@ class Elemwise(OpenMPOp):
for
arr
,
var
in
z
if
not
all_broadcastable
)
loop
=
"""
loop
=
f
"""
if(({cond1}) || ({cond2})){{
{contig}
}}else{{
{loop}
}}
"""
.
format
(
**
locals
())
"""
return
decl
,
checks
,
alloc
,
loop
,
""
def
c_code
(
self
,
node
,
nodename
,
inames
,
onames
,
sub
):
...
...
pytensor/tensor/elemwise_cgen.py
浏览文件 @
55b2f4fa
...
...
@@ -170,12 +170,14 @@ def make_alloc(loop_orders, dtype, sub, fortran="0"):
type
=
type
.
replace
(
"PYTENSOR_COMPLEX"
,
"NPY_COMPLEX"
)
nd
=
len
(
loop_orders
[
0
])
init_dims
=
compute_output_dims_lengths
(
"dims"
,
loop_orders
,
sub
)
olv
=
sub
[
"olv"
]
fail
=
sub
[
"fail"
]
# TODO: it would be interesting to allocate the output in such a
# way that its contiguous dimensions match one of the input's
# contiguous dimensions, or the dimension with the smallest
# stride. Right now, it is allocated to be C_CONTIGUOUS.
return
"""
return
f
"""
{{
npy_intp dims[{nd}];
//npy_intp* dims = (npy_intp*)malloc({nd} * sizeof(npy_intp));
...
...
@@ -203,7 +205,7 @@ def make_alloc(loop_orders, dtype, sub, fortran="0"):
{fail}
}}
}}
"""
.
format
(
**
dict
(
locals
(),
**
sub
))
"""
def
make_loop
(
loop_orders
,
dtypes
,
loop_tasks
,
sub
,
openmp
=
None
):
...
...
pytensor/tensor/extra_ops.py
浏览文件 @
55b2f4fa
...
...
@@ -68,7 +68,7 @@ class CpuContiguous(COp):
def
c_code
(
self
,
node
,
name
,
inames
,
onames
,
sub
):
(
x
,)
=
inames
(
y
,)
=
onames
code
=
"""
code
=
f
"""
if (!PyArray_CHKFLAGS({x}, NPY_ARRAY_C_CONTIGUOUS)){{
// check to see if output is contiguous first
if ({y} != NULL &&
...
...
@@ -86,7 +86,7 @@ class CpuContiguous(COp):
Py_XDECREF({y});
{y} = {x};
}}
"""
.
format
(
**
locals
())
"""
return
code
def
c_code_cache_version
(
self
):
...
...
@@ -161,13 +161,13 @@ class SearchsortedOp(COp):
def
c_init_code_struct
(
self
,
node
,
name
,
sub
):
side
=
sub
[
"params"
]
fail
=
sub
[
"fail"
]
return
"""
return
f
"""
PyObject* tmp_{name} = PyUnicode_FromString("right");
if (tmp_{name} == NULL)
{fail};
right_{name} = PyUnicode_Compare({side}, tmp_{name});
Py_DECREF(tmp_{name});
"""
.
format
(
**
locals
())
"""
def
c_code
(
self
,
node
,
name
,
inames
,
onames
,
sub
):
sorter
=
None
...
...
@@ -180,7 +180,7 @@ class SearchsortedOp(COp):
(
z
,)
=
onames
fail
=
sub
[
"fail"
]
return
"""
return
f
"""
Py_XDECREF({z});
{z} = (PyArrayObject*) PyArray_SearchSorted({x}, (PyObject*) {v},
right_{name} ? NPY_SEARCHLEFT : NPY_SEARCHRIGHT, (PyObject*) {sorter});
...
...
@@ -191,7 +191,7 @@ class SearchsortedOp(COp):
Py_XDECREF({z});
{z} = (PyArrayObject*) tmp;
}}
"""
.
format
(
**
locals
())
"""
def
c_code_cache_version
(
self
):
return
(
2
,)
...
...
@@ -348,11 +348,10 @@ class CumOp(COp):
def
c_code
(
self
,
node
,
name
,
inames
,
onames
,
sub
):
(
x
,)
=
inames
(
z
,)
=
onames
axis
=
self
.
axis
fail
=
sub
[
"fail"
]
params
=
sub
[
"params"
]
code
=
"""
code
=
f
"""
int axis = {params}->c_axis;
if (axis == 0 && PyArray_NDIM({x}) == 1)
axis = NPY_MAXDIMS;
...
...
@@ -389,7 +388,7 @@ class CumOp(COp):
// Because PyArray_CumSum/CumProd returns a newly created reference on t.
Py_XDECREF(t);
}}
"""
.
format
(
**
locals
())
"""
return
code
...
...
pytensor/tensor/math.py
浏览文件 @
55b2f4fa
...
...
@@ -215,14 +215,14 @@ class Argmax(COp):
if
len
(
self
.
axis
)
!=
1
:
raise
NotImplementedError
()
# params is only used here for now
axis_code
=
"""
axis_code
=
f
"""
axis = {params}->c_axis;
if(axis > PyArray_NDIM({x})-1 || axis < -PyArray_NDIM({x})){{
PyErr_SetString(PyExc_ValueError,
"Argmax, bad axis argument");
{fail}
}}
"""
.
format
(
**
locals
())
"""
ret
=
"""
int axis;
...
...
@@ -1314,7 +1314,8 @@ class Mean(FixedOpCAReduce):
def
__str__
(
self
):
if
self
.
axis
is
not
None
:
return
"Mean{{{}}}"
.
format
(
", "
.
join
(
str
(
x
)
for
x
in
self
.
axis
))
args
=
", "
.
join
(
str
(
x
)
for
x
in
self
.
axis
)
return
f
"Mean{{{args}}}"
else
:
return
"Mean"
...
...
pytensor/tensor/subtensor.py
浏览文件 @
55b2f4fa
...
...
@@ -1137,7 +1137,7 @@ class Subtensor(COp):
"""
rval
+=
"""
rval
+=
f
"""
// One more argument of the view
npy_intp xview_offset = 0;
...
...
@@ -1264,7 +1264,7 @@ class Subtensor(COp):
inner_ii += 1;
outer_ii += 1;
}}
"""
.
format
(
**
locals
())
"""
# print rval
return
rval
...
...
@@ -1284,19 +1284,19 @@ class Subtensor(COp):
decl
=
"PyArrayObject * xview = NULL;"
checkNDim
=
"""
checkNDim
=
f
"""
if (PyArray_NDIM({x}) != {ndim}){{
PyErr_SetString(PyExc_ValueError,
"Expected {ndim} dimensions input"
);
{fail}
}}
"""
.
format
(
**
locals
())
"""
get_xview
=
self
.
helper_c_code
(
node
,
name
,
inputs
,
outputs
,
sub
,
self
.
idx_list
,
view_ndim
)
build_view
=
"""
build_view
=
f
"""
//TODO: give this Op a second output so that this view can be cached
//TODO: alternatively, fix the memory leak on failure
Py_INCREF(PyArray_DESCR({x}));
...
...
@@ -1314,7 +1314,7 @@ class Subtensor(COp):
{{
{fail};
}}
"""
.
format
(
**
locals
())
"""
finish_view
=
f
"""
Py_XDECREF({z});
...
...
@@ -1761,7 +1761,7 @@ class IncSubtensor(COp):
copy_of_x
=
self
.
copy_of_x
(
x
)
copy_input_if_necessary
=
"""
copy_input_if_necessary
=
f
"""
if ({inplace})
{{
if ({x} != {z})
...
...
@@ -1780,7 +1780,7 @@ class IncSubtensor(COp):
{fail}
}}
}}
"""
.
format
(
**
locals
())
"""
# get info needed to make zview: a view of %(z)s
helper_args
=
self
.
get_helper_c_code_args
()
...
...
@@ -1799,7 +1799,7 @@ class IncSubtensor(COp):
# Make a view on the output, as we will write into it.
alloc_zview
=
self
.
make_view_array
(
z
,
view_ndim
)
build_view
=
"""
build_view
=
f
"""
//TODO: give this Op a second output so that this view can be cached
//TODO: alternatively, fix the memory leak on failure
{alloc_zview};
...
...
@@ -1807,13 +1807,13 @@ class IncSubtensor(COp):
{{
{fail};
}}
"""
.
format
(
**
locals
())
"""
copy_into
=
self
.
copy_into
(
"zview"
,
y
)
add_to_zview
=
self
.
add_to_zview
(
name
,
y
,
fail
)
make_modification
=
"""
make_modification
=
f
"""
if ({op_is_set})
{{
if ({copy_into}) // does broadcasting
...
...
@@ -1826,7 +1826,7 @@ class IncSubtensor(COp):
{{
{add_to_zview}
}}
"""
.
format
(
**
locals
())
"""
return
(
self
.
decl_view
()
+
copy_input_if_necessary
...
...
@@ -1896,7 +1896,7 @@ class IncSubtensor(COp):
"""
return
"""Py_INCREF(PyArray_DESCR({x}));
return
f
"""Py_INCREF(PyArray_DESCR({x}));
zview = (PyArrayObject*)PyArray_NewFromDescr(
&PyArray_Type,
PyArray_DESCR({x}),
...
...
@@ -1906,7 +1906,7 @@ class IncSubtensor(COp):
PyArray_BYTES({x}) + xview_offset, //PyArray_DATA({x}),
PyArray_FLAGS({x}),
NULL);
"""
.
format
(
**
locals
())
"""
def
get_helper_c_code_args
(
self
):
"""
...
...
@@ -1939,7 +1939,7 @@ class IncSubtensor(COp):
"""
return
"""
return
f
"""
PyArrayObject * add_rval = (PyArrayObject*)PyNumber_InPlaceAdd(
(PyObject*)zview, py_{x});
if (add_rval)
...
...
@@ -1952,7 +1952,7 @@ class IncSubtensor(COp):
{{
Py_DECREF(zview);
{fail};
}}"""
.
format
(
**
locals
())
}}"""
def
infer_shape
(
self
,
fgraph
,
node
,
shapes
):
return
[
shapes
[
0
]]
...
...
@@ -2162,7 +2162,7 @@ class AdvancedSubtensor1(COp):
a_name
,
i_name
=
input_names
[
0
],
input_names
[
1
]
output_name
=
output_names
[
0
]
fail
=
sub
[
"fail"
]
return
"""
return
f
"""
PyArrayObject *indices;
int i_type = PyArray_TYPE({i_name});
if (i_type != NPY_INTP) {{
...
...
@@ -2237,7 +2237,7 @@ class AdvancedSubtensor1(COp):
{a_name}, (PyObject*)indices, 0, {output_name}, NPY_RAISE);
Py_DECREF(indices);
if ({output_name} == NULL) {fail};
"""
.
format
(
**
locals
())
"""
def
c_code_cache_version
(
self
):
return
(
0
,
1
,
2
)
...
...
@@ -2523,8 +2523,10 @@ class AdvancedIncSubtensor1(COp):
x
,
y
,
idx
=
input_names
out
=
output_names
[
0
]
copy_of_x
=
self
.
copy_of_x
(
x
)
params
=
sub
[
"params"
]
fail
=
sub
[
"fail"
]
return
"""
return
f
"""
PyObject* rval = NULL;
if ({params}->inplace)
{{
...
...
@@ -2548,17 +2550,7 @@ class AdvancedIncSubtensor1(COp):
{fail};
}}
Py_XDECREF(rval);
"""
.
format
(
**
dict
(
x
=
x
,
y
=
y
,
idx
=
idx
,
out
=
out
,
copy_of_x
=
copy_of_x
,
params
=
sub
[
"params"
],
fail
=
sub
[
"fail"
],
)
)
"""
def
c_code_cache_version
(
self
):
return
(
8
,)
...
...
pytensor/tensor/type.py
浏览文件 @
55b2f4fa
...
...
@@ -475,25 +475,28 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
def
c_declare
(
self
,
name
,
sub
,
check_input
=
True
):
if
check_input
:
check
=
"""
dtype
=
self
.
dtype_specs
()[
1
]
check
=
f
"""
typedef {dtype} dtype_{name};
"""
.
format
(
**
dict
(
sub
,
name
=
name
,
dtype
=
self
.
dtype_specs
()[
1
]))
"""
else
:
check
=
""
declaration
=
"""
declaration
=
f
"""
PyArrayObject* {name};
"""
.
format
(
**
dict
(
sub
,
name
=
name
,
dtype
=
self
.
dtype_specs
()[
1
]))
"""
return
declaration
+
check
def
c_init
(
self
,
name
,
sub
):
return
"""
return
f
"""
{name} = NULL;
"""
.
format
(
**
dict
(
sub
,
name
=
name
,
type_num
=
self
.
dtype_specs
()[
2
]))
"""
def
c_extract
(
self
,
name
,
sub
,
check_input
=
True
,
**
kwargs
):
if
check_input
:
check
=
"""
fail
=
sub
[
"fail"
]
type_num
=
self
.
dtype_specs
()[
2
]
check
=
f
"""
{name} = NULL;
if (py_{name} == Py_None) {{
// We can either fail here or set {name} to NULL and rely on Ops
...
...
@@ -541,28 +544,27 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
{type_num}, PyArray_TYPE((PyArrayObject*) py_{name}));
{fail}
}}
"""
.
format
(
**
dict
(
sub
,
name
=
name
,
type_num
=
self
.
dtype_specs
()[
2
]))
"""
else
:
check
=
""
return
(
check
+
"""
+
f
"""
{name} = (PyArrayObject*)(py_{name});
Py_XINCREF({name});
"""
.
format
(
**
dict
(
sub
,
name
=
name
,
type_num
=
self
.
dtype_specs
()[
2
]))
"""
)
def
c_cleanup
(
self
,
name
,
sub
):
return
"""
return
f
"""
if ({name}) {{
Py_XDECREF({name});
}}
"""
.
format
(
**
locals
())
"""
def
c_sync
(
self
,
name
,
sub
):
fail
=
sub
[
"fail"
]
type_num
=
self
.
dtype_specs
()[
2
]
return
"""
return
f
"""
{{Py_XDECREF(py_{name});}}
if (!{name}) {{
Py_INCREF(Py_None);
...
...
@@ -597,7 +599,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
);
{fail}
}}
"""
.
format
(
**
locals
())
"""
def
c_headers
(
self
,
**
kwargs
):
return
ps
.
get_scalar_type
(
self
.
dtype
)
.
c_headers
(
**
kwargs
)
...
...
pytensor/typed_list/basic.py
浏览文件 @
55b2f4fa
...
...
@@ -102,13 +102,13 @@ class GetItem(COp):
x_name
,
index
=
inp
[
0
],
inp
[
1
]
output_name
=
out
[
0
]
fail
=
sub
[
"fail"
]
return
"""
return
f
"""
{output_name} = (typeof {output_name}) PyList_GetItem( (PyObject*) {x_name}, *((npy_int64 *) PyArray_DATA({index})));
if({output_name} == NULL){{
{fail}
}}
Py_INCREF({output_name});
"""
.
format
(
**
locals
())
"""
def
c_code_cache_version
(
self
):
return
(
1
,)
...
...
@@ -169,16 +169,16 @@ class Append(COp):
output_name
=
out
[
0
]
fail
=
sub
[
"fail"
]
if
not
self
.
inplace
:
init
=
"""
init
=
f
"""
{output_name} = (PyListObject*) PyList_GetSlice((PyObject*) {x_name}, 0, PyList_GET_SIZE((PyObject*) {x_name})) ;
"""
.
format
(
**
locals
())
"""
else
:
init
=
f
"""
{output_name} = {x_name};
"""
return
(
init
+
"""
+
f
"""
if({output_name}==NULL){{
{fail}
}};
...
...
@@ -186,7 +186,7 @@ class Append(COp):
{fail}
}};
Py_INCREF({output_name});
"""
.
format
(
**
locals
())
"""
)
def
c_code_cache_version
(
self
):
...
...
@@ -249,16 +249,16 @@ class Extend(COp):
output_name
=
out
[
0
]
fail
=
sub
[
"fail"
]
if
not
self
.
inplace
:
init
=
"""
init
=
f
"""
{output_name} = (PyListObject*) PyList_GetSlice((PyObject*) {x_name}, 0, PyList_GET_SIZE((PyObject*) {x_name})) ;
"""
.
format
(
**
locals
())
"""
else
:
init
=
f
"""
{output_name} = {x_name};
"""
return
(
init
+
"""
+
f
"""
int i =0;
int length = PyList_GET_SIZE((PyObject*) {toAppend});
if({output_name}==NULL){{
...
...
@@ -270,7 +270,7 @@ class Extend(COp):
}};
}}
Py_INCREF({output_name});
"""
.
format
(
**
locals
())
"""
)
def
c_code_cache_version_
(
self
):
...
...
@@ -337,16 +337,16 @@ class Insert(COp):
output_name
=
out
[
0
]
fail
=
sub
[
"fail"
]
if
not
self
.
inplace
:
init
=
"""
init
=
f
"""
{output_name} = (PyListObject*) PyList_GetSlice((PyObject*) {x_name}, 0, PyList_GET_SIZE((PyObject*) {x_name})) ;
"""
.
format
(
**
locals
())
"""
else
:
init
=
f
"""
{output_name} = {x_name};
"""
return
(
init
+
"""
+
f
"""
if({output_name}==NULL){{
{fail}
}};
...
...
@@ -354,7 +354,7 @@ class Insert(COp):
{fail}
}};
Py_INCREF({output_name});
"""
.
format
(
**
locals
())
"""
)
def
c_code_cache_version
(
self
):
...
...
@@ -465,16 +465,16 @@ class Reverse(COp):
output_name
=
out
[
0
]
fail
=
sub
[
"fail"
]
if
not
self
.
inplace
:
init
=
"""
init
=
f
"""
{output_name} = (PyListObject*) PyList_GetSlice((PyObject*) {x_name}, 0, PyList_GET_SIZE((PyObject*) {x_name})) ;
"""
.
format
(
**
locals
())
"""
else
:
init
=
f
"""
{output_name} = {x_name};
"""
return
(
init
+
"""
+
f
"""
if({output_name}==NULL){{
{fail}
}};
...
...
@@ -482,7 +482,7 @@ class Reverse(COp):
{fail}
}};
Py_INCREF({output_name});
"""
.
format
(
**
locals
())
"""
)
def
c_code_cache_version
(
self
):
...
...
@@ -595,13 +595,12 @@ class Length(COp):
def
c_code
(
self
,
node
,
name
,
inp
,
out
,
sub
):
x_name
=
inp
[
0
]
output_name
=
out
[
0
]
fail
=
sub
[
"fail"
]
return
"""
return
f
"""
if(!{output_name})
{output_name}=(PyArrayObject*)PyArray_EMPTY(0, NULL, NPY_INT64, 0);
((npy_int64*)PyArray_DATA({output_name}))[0]=PyList_Size((PyObject*){x_name});
Py_INCREF({output_name});
"""
.
format
(
**
locals
())
"""
def
c_code_cache_version
(
self
):
return
(
1
,)
...
...
pytensor/typed_list/type.py
浏览文件 @
55b2f4fa
...
...
@@ -110,26 +110,27 @@ class TypedListType(CType):
def
c_extract
(
self
,
name
,
sub
,
check_input
=
True
,
**
kwargs
):
if
check_input
:
pre
=
"""
fail
=
sub
[
"fail"
]
pre
=
f
"""
if (!PyList_Check(py_{name})) {{
PyErr_SetString(PyExc_TypeError, "expected a list");
{fail}
}}"""
.
format
(
**
dict
(
name
=
name
,
fail
=
sub
[
"fail"
]))
}}"""
else
:
pre
=
""
return
(
pre
+
"""
+
f
"""
{name} = (PyListObject*) (py_{name});
"""
.
format
(
**
dict
(
name
=
name
,
fail
=
sub
[
"fail"
]))
"""
)
def
c_sync
(
self
,
name
,
sub
):
return
"""
return
f
"""
Py_XDECREF(py_{name});
py_{name} = (PyObject*)({name});
Py_INCREF(py_{name});
"""
.
format
(
**
dict
(
name
=
name
))
"""
def
c_cleanup
(
self
,
name
,
sub
):
return
""
...
...
tests/compile/test_debugmode.py
浏览文件 @
55b2f4fa
...
...
@@ -64,7 +64,8 @@ class BROKEN_ON_PURPOSE_Add(COp):
def
c_code
(
self
,
node
,
name
,
inp
,
out
,
sub
):
a
,
b
=
inp
(
z
,)
=
out
return
"""
fail
=
sub
[
"fail"
]
return
f
"""
if (PyArray_NDIM({a}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a) != 1"); {fail};}}
if (PyArray_NDIM({b}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 1"); {fail};}}
...
...
@@ -96,7 +97,7 @@ class BROKEN_ON_PURPOSE_Add(COp):
+ ((double*)PyArray_GETPTR1({b}, m))[0] ;
}}
}}
"""
.
format
(
**
dict
(
locals
(),
**
sub
))
"""
# inconsistent is a invalid op, whose perform and c_code do not match
...
...
@@ -632,7 +633,8 @@ class BrokenCImplementationAdd(COp):
a
,
b
=
inp
(
z
,)
=
out
debug
=
0
return
"""
fail
=
sub
[
"fail"
]
return
f
"""
//printf("executing c_code
\\
n");
if (PyArray_NDIM({a}) != 2) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a) != 2"); {fail};}}
if (PyArray_NDIM({b}) != 2) {{PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 2"); {fail};}}
...
...
@@ -690,7 +692,7 @@ class BrokenCImplementationAdd(COp):
}}
}}
}}
"""
.
format
(
**
dict
(
locals
(),
**
sub
))
"""
class
VecAsRowAndCol
(
Op
):
...
...
tests/link/c/test_basic.py
浏览文件 @
55b2f4fa
...
...
@@ -39,7 +39,8 @@ class TDouble(CType):
return
str
(
data
)
def
c_extract
(
self
,
name
,
sub
,
check_input
=
True
,
**
kwargs
):
return
"""
fail
=
sub
[
"fail"
]
return
f
"""
if (!PyFloat_Check(py_{name})) {{
PyErr_SetString(PyExc_TypeError, "not a double!");
{fail}
...
...
@@ -47,7 +48,7 @@ class TDouble(CType):
{name} = PyFloat_AsDouble(py_{name});
{name}_bad_thing = NULL;
//printf("Extracting {name}
\\
n");
"""
.
format
(
**
dict
(
locals
(),
**
sub
))
"""
def
c_sync
(
self
,
name
,
sub
):
return
f
"""
...
...
tests/link/c/test_cmodule.py
浏览文件 @
55b2f4fa
...
...
@@ -189,12 +189,11 @@ def mock_system(request):
@pytest.fixture
()
def
cxx_search_dirs
(
blas_libs
,
mock_system
):
libext
=
{
"Linux"
:
"so"
,
"Windows"
:
"dll"
,
"Darwin"
:
"dylib"
}
libtemplate
=
f
"{{lib}}.{libext[mock_system]}"
libraries
=
[]
with
tempfile
.
TemporaryDirectory
()
as
d
:
flags
=
None
for
lib
in
blas_libs
:
lib_path
=
Path
(
d
)
/
libtemplate
.
format
(
lib
=
lib
)
lib_path
=
Path
(
d
)
/
f
"{lib}.{libext[mock_system]}"
lib_path
.
write_bytes
(
b
"1"
)
libraries
.
append
(
lib_path
)
if
flags
is
None
:
...
...
@@ -262,14 +261,13 @@ def test_default_blas_ldflags_no_cxx():
@pytest.fixture
()
def
windows_conda_libs
(
blas_libs
):
libtemplate
=
"{lib}.dll"
libraries
=
[]
with
tempfile
.
TemporaryDirectory
()
as
d
:
subdir
=
Path
(
d
)
/
"Library"
/
"bin"
subdir
.
mkdir
(
exist_ok
=
True
,
parents
=
True
)
flags
=
f
'-L"{subdir}"'
for
lib
in
blas_libs
:
lib_path
=
subdir
/
libtemplate
.
format
(
lib
=
lib
)
lib_path
=
subdir
/
f
"{lib}.dll"
lib_path
.
write_bytes
(
b
"1"
)
libraries
.
append
(
lib_path
)
flags
+=
f
" -l{lib}"
...
...
tests/link/c/test_op.py
浏览文件 @
55b2f4fa
...
...
@@ -79,10 +79,10 @@ class StructOp(COp):
return
f
"counter{name} = 0;"
def
c_code
(
self
,
node
,
name
,
input_names
,
outputs_names
,
sub
):
return
"""
{out} = counter{name};
return
f
"""
{out
puts_names[0]
} = counter{name};
counter{name}++;
"""
.
format
(
**
dict
(
out
=
outputs_names
[
0
],
name
=
name
))
"""
def
c_code_cache_version
(
self
):
return
(
1
,)
...
...
tests/link/c/test_params_type.py
浏览文件 @
55b2f4fa
...
...
@@ -73,30 +73,24 @@ class QuadraticOpFunc(COp):
"""
def
c_code
(
self
,
node
,
name
,
inputs
,
outputs
,
sub
):
return
"""
{float_type} a = ({float_type}) (*(npy_float64*) PyArray_GETPTR1({coeff}->a, 0)); // 0-D TensorType.
{float_type} b = {coeff}->b; // ScalarType.
{float_type} c = ({float_type}) PyFloat_AsDouble({coeff}->c); // Generic.
X
=
inputs
[
0
]
Y
=
outputs
[
0
]
float_type
=
node
.
inputs
[
0
]
.
type
.
c_element_type
()
return
f
"""
{float_type} a = ({float_type}) (*(npy_float64*) PyArray_GETPTR1({sub['params']}->a, 0)); // 0-D TensorType.
{float_type} b = {sub['params']}->b; // ScalarType.
{float_type} c = ({float_type}) PyFloat_AsDouble({sub['params']}->c); // Generic.
Py_XDECREF({Y});
{Y} = (PyArrayObject*)PyArray_EMPTY(PyArray_NDIM({X}), PyArray_DIMS({X}), PyArray_TYPE({X}), PyArray_IS_F_CONTIGUOUS({X}));
if (PyArray_CopyInto({Y}, {X}) != 0) {{
PyErr_SetString(PyExc_RuntimeError, "Unable to copy input into output.");
{
fail
}
{
sub['fail']
}
}};
if (quadratic_{name}({Y}, a, b, c) != 0) {{
PyErr_SetString(PyExc_RuntimeError, "Unable to compute quadratic function.");
{
fail
}
{
sub['fail']
}
}}
"""
.
format
(
**
dict
(
name
=
name
,
coeff
=
sub
[
"params"
],
fail
=
sub
[
"fail"
],
X
=
inputs
[
0
],
Y
=
outputs
[
0
],
float_type
=
node
.
inputs
[
0
]
.
type
.
c_element_type
(),
)
)
"""
# Same op as above, but implemented as a ExternalCOp (with C code in an
...
...
tests/link/c/test_type.py
浏览文件 @
55b2f4fa
...
...
@@ -25,11 +25,12 @@ Py_XDECREF((PyObject *)p);
"""
def
c_code
(
self
,
node
,
name
,
inps
,
outs
,
sub
):
return
"""
Py_XDECREF({out});
{out} = (void *){inp};
Py_INCREF({inp});
"""
.
format
(
**
dict
(
out
=
outs
[
0
],
inp
=
inps
[
0
]))
return
f
"""
Py_XDECREF({outs[0]});
{outs[0]} = (void *){inps[0]};
Py_INCREF({inps[0]});
"""
# FIXME: should it not be outs[0]?
def
c_code_cache_version
(
self
):
return
(
0
,)
...
...
@@ -52,11 +53,11 @@ Py_XDECREF((PyObject *)p);
"""
def
c_code
(
self
,
node
,
name
,
inps
,
outs
,
sub
):
return
"""
Py_XDECREF({out});
{out
} = (PyArrayObject *){inp
};
Py_INCREF({out});
"""
.
format
(
**
dict
(
out
=
outs
[
0
],
inp
=
inps
[
0
]))
return
f
"""
Py_XDECREF({out
s[0]
});
{out
s[0]} = (PyArrayObject *){inps[0]
};
Py_INCREF({out
s[0]
});
"""
def
c_code_cache_version
(
self
):
return
(
0
,)
...
...
@@ -136,7 +137,12 @@ class MyOpEnumList(COp):
return
(
1
,)
def
c_code
(
self
,
node
,
name
,
inputs
,
outputs
,
sub
):
return
"""
op
=
sub
[
"params"
]
o
=
outputs
[
0
]
a
=
inputs
[
0
]
b
=
inputs
[
1
]
fail
=
sub
[
"fail"
]
return
f
"""
switch({op}) {{
case ADD:
{o} = {a} + {b};
...
...
@@ -154,15 +160,7 @@ class MyOpEnumList(COp):
{{{fail}}}
break;
}}
"""
.
format
(
**
dict
(
op
=
sub
[
"params"
],
o
=
outputs
[
0
],
a
=
inputs
[
0
],
b
=
inputs
[
1
],
fail
=
sub
[
"fail"
],
)
)
"""
class
MyOpCEnumType
(
COp
):
...
...
@@ -201,15 +199,10 @@ class MyOpCEnumType(COp):
return
(
3
,)
def
c_code
(
self
,
node
,
name
,
inputs
,
outputs
,
sub
):
return
"""
{o} = {val};
"""
.
format
(
**
dict
(
o
=
outputs
[
0
],
# params in C code will already contains expected C constant value.
val
=
sub
[
"params"
],
)
)
return
f
"""
{outputs[0]} = {sub['params']};
"""
class
TestEnumTypes
:
...
...
tests/tensor/conv/c_conv3d_corr3d_ref.py
浏览文件 @
55b2f4fa
...
...
@@ -323,7 +323,9 @@ class BaseCorr3dMM(OpenMPOp, _NoPythonOp):
)
depth
=
"-1"
return
"""
fail
=
sub
[
"fail"
]
params
=
sub
[
"params"
]
return
f
"""
// Mandatory args
int direction = {params}->direction; // forward, bprop weights, bprop inputs
...
...
@@ -568,18 +570,7 @@ class BaseCorr3dMM(OpenMPOp, _NoPythonOp):
}}
assert (out2 == *out);
"""
.
format
(
**
dict
(
bottom
=
bottom
,
weights
=
weights
,
top
=
top
,
height
=
height
,
width
=
width
,
depth
=
depth
,
fail
=
sub
[
"fail"
],
params
=
sub
[
"params"
],
)
)
"""
class
Corr3dMM
(
BaseCorr3dMM
):
...
...
tests/tensor/conv/c_conv_corr_ref.py
浏览文件 @
55b2f4fa
...
...
@@ -322,7 +322,9 @@ class BaseCorrMM(OpenMPOp, _NoPythonOp):
)
width
=
"-1"
return
"""
fail
=
sub
[
"fail"
]
params
=
sub
[
"params"
]
return
f
"""
// Mandatory args
int direction = {params}->direction; // forward, bprop weights, bprop inputs
...
...
@@ -614,17 +616,7 @@ class BaseCorrMM(OpenMPOp, _NoPythonOp):
}}
assert (out2 == *out);
"""
.
format
(
**
dict
(
bottom
=
bottom
,
weights
=
weights
,
top
=
top
,
height
=
height
,
width
=
width
,
fail
=
sub
[
"fail"
],
params
=
sub
[
"params"
],
)
)
"""
class
CorrMM
(
BaseCorrMM
):
...
...
tests/tensor/rewriting/test_elemwise.py
浏览文件 @
55b2f4fa
...
...
@@ -1391,9 +1391,9 @@ class TimesN(ps.basic.UnaryScalarOp):
def
c_support_code_apply
(
self
,
node
,
nodename
):
n
=
str
(
self
.
n
)
return
"""
return
f
"""
float {nodename}_timesn(float x) {{ return x * {n}; }}
"""
.
format
(
**
locals
())
"""
def
c_code
(
self
,
node
,
name
,
inputs
,
outputs
,
sub
):
(
x
,)
=
inputs
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论