Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
10f285a1
提交
10f285a1
authored
7月 07, 2024
作者:
Virgile Andreani
提交者:
Virgile Andreani
7月 09, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Use generators when appropriate
上级
8ae2a195
隐藏空白字符变更
内嵌
并排
正在显示
31 个修改的文件
包含
96 行增加
和
134 行删除
+96
-134
configparser.py
pytensor/configparser.py
+1
-1
formatting.py
pytensor/d3viz/formatting.py
+1
-1
basic.py
pytensor/graph/rewriting/basic.py
+3
-3
basic.py
pytensor/link/c/basic.py
+5
-5
cmodule.py
pytensor/link/c/cmodule.py
+3
-3
params_type.py
pytensor/link/c/params_type.py
+4
-10
elemwise.py
pytensor/link/numba/dispatch/elemwise.py
+2
-2
scalar.py
pytensor/link/numba/dispatch/scalar.py
+3
-7
scan.py
pytensor/link/numba/dispatch/scan.py
+1
-1
tensor_basic.py
pytensor/link/numba/dispatch/tensor_basic.py
+4
-8
vectorize_codegen.py
pytensor/link/numba/dispatch/vectorize_codegen.py
+2
-4
vm.py
pytensor/link/vm.py
+3
-5
check_duplicate_key.py
pytensor/misc/check_duplicate_key.py
+2
-2
printing.py
pytensor/printing.py
+13
-13
basic.py
pytensor/scalar/basic.py
+4
-4
loop.py
pytensor/scalar/loop.py
+1
-1
rewriting.py
pytensor/scan/rewriting.py
+1
-1
utils.py
pytensor/scan/utils.py
+7
-9
basic.py
pytensor/tensor/basic.py
+2
-2
blas.py
pytensor/tensor/blas.py
+1
-1
blockwise.py
pytensor/tensor/blockwise.py
+2
-4
elemwise.py
pytensor/tensor/elemwise.py
+14
-22
extra_ops.py
pytensor/tensor/extra_ops.py
+1
-1
math.py
pytensor/tensor/math.py
+3
-5
shape.py
pytensor/tensor/shape.py
+2
-4
slinalg.py
pytensor/tensor/slinalg.py
+1
-1
subtensor.py
pytensor/tensor/subtensor.py
+3
-5
type.py
pytensor/tensor/type.py
+1
-1
variable.py
pytensor/tensor/variable.py
+4
-6
test_math.py
tests/tensor/test_math.py
+1
-1
utils.py
tests/tensor/utils.py
+1
-1
没有找到文件。
pytensor/configparser.py
浏览文件 @
10f285a1
...
@@ -104,7 +104,7 @@ class PyTensorConfigParser:
...
@@ -104,7 +104,7 @@ class PyTensorConfigParser:
)
)
return
hash_from_code
(
return
hash_from_code
(
"
\n
"
.
join
(
"
\n
"
.
join
(
[
f
"{cv.name} = {cv.__get__(self, self.__class__)}"
for
cv
in
all_opts
]
f
"{cv.name} = {cv.__get__(self, self.__class__)}"
for
cv
in
all_opts
)
)
)
)
...
...
pytensor/d3viz/formatting.py
浏览文件 @
10f285a1
...
@@ -360,7 +360,7 @@ def dict_to_pdnode(d):
...
@@ -360,7 +360,7 @@ def dict_to_pdnode(d):
for
k
,
v
in
d
.
items
():
for
k
,
v
in
d
.
items
():
if
v
is
not
None
:
if
v
is
not
None
:
if
isinstance
(
v
,
list
):
if
isinstance
(
v
,
list
):
v
=
"
\t
"
.
join
(
[
str
(
x
)
for
x
in
v
]
)
v
=
"
\t
"
.
join
(
str
(
x
)
for
x
in
v
)
else
:
else
:
v
=
str
(
v
)
v
=
str
(
v
)
v
=
str
(
v
)
v
=
str
(
v
)
...
...
pytensor/graph/rewriting/basic.py
浏览文件 @
10f285a1
...
@@ -1264,7 +1264,7 @@ class SequentialNodeRewriter(NodeRewriter):
...
@@ -1264,7 +1264,7 @@ class SequentialNodeRewriter(NodeRewriter):
return
getattr
(
return
getattr
(
self
,
self
,
"__name__"
,
"__name__"
,
f
"{type(self).__name__}({','.join(
[str(o) for o in self.rewrites]
)})"
,
f
"{type(self).__name__}({','.join(
str(o) for o in self.rewrites
)})"
,
)
)
def
tracks
(
self
):
def
tracks
(
self
):
...
@@ -1666,7 +1666,7 @@ class PatternNodeRewriter(NodeRewriter):
...
@@ -1666,7 +1666,7 @@ class PatternNodeRewriter(NodeRewriter):
if
isinstance
(
pattern
,
list
|
tuple
):
if
isinstance
(
pattern
,
list
|
tuple
):
return
"{}({})"
.
format
(
return
"{}({})"
.
format
(
str
(
pattern
[
0
]),
str
(
pattern
[
0
]),
", "
.
join
(
[
pattern_to_str
(
p
)
for
p
in
pattern
[
1
:]
]),
", "
.
join
(
pattern_to_str
(
p
)
for
p
in
pattern
[
1
:
]),
)
)
elif
isinstance
(
pattern
,
dict
):
elif
isinstance
(
pattern
,
dict
):
return
"{} subject to {}"
.
format
(
return
"{} subject to {}"
.
format
(
...
@@ -2569,7 +2569,7 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter):
...
@@ -2569,7 +2569,7 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter):
d
=
sorted
(
d
=
sorted
(
loop_process_count
[
i
]
.
items
(),
key
=
lambda
a
:
a
[
1
],
reverse
=
True
loop_process_count
[
i
]
.
items
(),
key
=
lambda
a
:
a
[
1
],
reverse
=
True
)
)
loop_times
=
" "
.
join
(
[
str
((
str
(
k
),
v
))
for
k
,
v
in
d
[:
5
]
])
loop_times
=
" "
.
join
(
str
((
str
(
k
),
v
))
for
k
,
v
in
d
[:
5
])
if
len
(
d
)
>
5
:
if
len
(
d
)
>
5
:
loop_times
+=
" ..."
loop_times
+=
" ..."
print
(
print
(
...
...
pytensor/link/c/basic.py
浏览文件 @
10f285a1
...
@@ -235,16 +235,16 @@ def struct_gen(args, struct_builders, blocks, sub):
...
@@ -235,16 +235,16 @@ def struct_gen(args, struct_builders, blocks, sub):
behavior
=
code_gen
(
blocks
)
behavior
=
code_gen
(
blocks
)
# declares the storage
# declares the storage
storage_decl
=
"
\n
"
.
join
(
[
f
"PyObject* {arg};"
for
arg
in
args
]
)
storage_decl
=
"
\n
"
.
join
(
f
"PyObject* {arg};"
for
arg
in
args
)
# in the constructor, sets the storage to the arguments
# in the constructor, sets the storage to the arguments
storage_set
=
"
\n
"
.
join
(
[
f
"this->{arg} = {arg};"
for
arg
in
args
]
)
storage_set
=
"
\n
"
.
join
(
f
"this->{arg} = {arg};"
for
arg
in
args
)
# increments the storage's refcount in the constructor
# increments the storage's refcount in the constructor
storage_incref
=
"
\n
"
.
join
(
[
f
"Py_XINCREF({arg});"
for
arg
in
args
]
)
storage_incref
=
"
\n
"
.
join
(
f
"Py_XINCREF({arg});"
for
arg
in
args
)
# decrements the storage's refcount in the destructor
# decrements the storage's refcount in the destructor
storage_decref
=
"
\n
"
.
join
(
[
f
"Py_XDECREF(this->{arg});"
for
arg
in
args
]
)
storage_decref
=
"
\n
"
.
join
(
f
"Py_XDECREF(this->{arg});"
for
arg
in
args
)
args_names
=
", "
.
join
(
args
)
args_names
=
", "
.
join
(
args
)
args_decl
=
", "
.
join
(
[
f
"PyObject* {arg}"
for
arg
in
args
]
)
args_decl
=
", "
.
join
(
f
"PyObject* {arg}"
for
arg
in
args
)
# The following code stores the exception data in __ERROR, which
# The following code stores the exception data in __ERROR, which
# is a special field of the struct. __ERROR is a list of length 3
# is a special field of the struct. __ERROR is a list of length 3
...
...
pytensor/link/c/cmodule.py
浏览文件 @
10f285a1
...
@@ -2003,7 +2003,7 @@ def try_blas_flag(flags):
...
@@ -2003,7 +2003,7 @@ def try_blas_flag(flags):
cflags
=
list
(
flags
)
cflags
=
list
(
flags
)
# to support path that includes spaces, we need to wrap it with double quotes on Windows
# to support path that includes spaces, we need to wrap it with double quotes on Windows
path_wrapper
=
'"'
if
os
.
name
==
"nt"
else
""
path_wrapper
=
'"'
if
os
.
name
==
"nt"
else
""
cflags
.
extend
(
[
f
"-L{path_wrapper}{d}{path_wrapper}"
for
d
in
std_lib_dirs
()]
)
cflags
.
extend
(
f
"-L{path_wrapper}{d}{path_wrapper}"
for
d
in
std_lib_dirs
()
)
res
=
GCC_compiler
.
try_compile_tmp
(
res
=
GCC_compiler
.
try_compile_tmp
(
test_code
,
tmp_prefix
=
"try_blas_"
,
flags
=
cflags
,
try_run
=
True
test_code
,
tmp_prefix
=
"try_blas_"
,
flags
=
cflags
,
try_run
=
True
...
@@ -2573,8 +2573,8 @@ class GCC_compiler(Compiler):
...
@@ -2573,8 +2573,8 @@ class GCC_compiler(Compiler):
cmd
.
extend
(
preargs
)
cmd
.
extend
(
preargs
)
# to support path that includes spaces, we need to wrap it with double quotes on Windows
# to support path that includes spaces, we need to wrap it with double quotes on Windows
path_wrapper
=
'"'
if
os
.
name
==
"nt"
else
""
path_wrapper
=
'"'
if
os
.
name
==
"nt"
else
""
cmd
.
extend
(
[
f
"-I{path_wrapper}{idir}{path_wrapper}"
for
idir
in
include_dirs
]
)
cmd
.
extend
(
f
"-I{path_wrapper}{idir}{path_wrapper}"
for
idir
in
include_dirs
)
cmd
.
extend
(
[
f
"-L{path_wrapper}{ldir}{path_wrapper}"
for
ldir
in
lib_dirs
]
)
cmd
.
extend
(
f
"-L{path_wrapper}{ldir}{path_wrapper}"
for
ldir
in
lib_dirs
)
if
hide_symbols
and
sys
.
platform
!=
"win32"
:
if
hide_symbols
and
sys
.
platform
!=
"win32"
:
# This has been available since gcc 4.0 so we suppose it
# This has been available since gcc 4.0 so we suppose it
# is always available. We pass it here since it
# is always available. We pass it here since it
...
...
pytensor/link/c/params_type.py
浏览文件 @
10f285a1
...
@@ -263,9 +263,7 @@ class Params(dict):
...
@@ -263,9 +263,7 @@ class Params(dict):
def
__repr__
(
self
):
def
__repr__
(
self
):
return
"Params({})"
.
format
(
return
"Params({})"
.
format
(
", "
.
join
(
", "
.
join
((
f
"{k}:{type(self[k]).__name__}:{self[k]}"
)
for
k
in
sorted
(
self
))
[(
f
"{k}:{type(self[k]).__name__}:{self[k]}"
)
for
k
in
sorted
(
self
)]
)
)
)
def
__getattr__
(
self
,
key
):
def
__getattr__
(
self
,
key
):
...
@@ -425,9 +423,7 @@ class ParamsType(CType):
...
@@ -425,9 +423,7 @@ class ParamsType(CType):
def
__repr__
(
self
):
def
__repr__
(
self
):
return
"ParamsType<{}>"
.
format
(
return
"ParamsType<{}>"
.
format
(
", "
.
join
(
", "
.
join
((
f
"{self.fields[i]}:{self.types[i]}"
)
for
i
in
range
(
self
.
length
))
[(
f
"{self.fields[i]}:{self.types[i]}"
)
for
i
in
range
(
self
.
length
)]
)
)
)
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
...
@@ -748,10 +744,8 @@ class ParamsType(CType):
...
@@ -748,10 +744,8 @@ class ParamsType(CType):
}}
}}
"""
.
format
(
"""
.
format
(
"
\n
"
.
join
(
"
\n
"
.
join
(
[
(
"case
%
d: extract_
%
s(object); break;"
%
(
i
,
self
.
fields
[
i
]))
(
"case
%
d: extract_
%
s(object); break;"
%
(
i
,
self
.
fields
[
i
]))
for
i
in
range
(
self
.
length
)
for
i
in
range
(
self
.
length
)
]
)
)
)
)
final_struct_code
=
"""
final_struct_code
=
"""
...
...
pytensor/link/numba/dispatch/elemwise.py
浏览文件 @
10f285a1
...
@@ -485,8 +485,8 @@ def numba_funcify_Elemwise(op, node, **kwargs):
...
@@ -485,8 +485,8 @@ def numba_funcify_Elemwise(op, node, **kwargs):
nout
=
len
(
node
.
outputs
)
nout
=
len
(
node
.
outputs
)
core_op_fn
=
store_core_outputs
(
scalar_op_fn
,
nin
=
nin
,
nout
=
nout
)
core_op_fn
=
store_core_outputs
(
scalar_op_fn
,
nin
=
nin
,
nout
=
nout
)
input_bc_patterns
=
tuple
(
[
inp
.
type
.
broadcastable
for
inp
in
node
.
inputs
]
)
input_bc_patterns
=
tuple
(
inp
.
type
.
broadcastable
for
inp
in
node
.
inputs
)
output_bc_patterns
=
tuple
(
[
out
.
type
.
broadcastable
for
out
in
node
.
outputs
]
)
output_bc_patterns
=
tuple
(
out
.
type
.
broadcastable
for
out
in
node
.
outputs
)
output_dtypes
=
tuple
(
out
.
type
.
dtype
for
out
in
node
.
outputs
)
output_dtypes
=
tuple
(
out
.
type
.
dtype
for
out
in
node
.
outputs
)
inplace_pattern
=
tuple
(
op
.
inplace_pattern
.
items
())
inplace_pattern
=
tuple
(
op
.
inplace_pattern
.
items
())
core_output_shapes
=
tuple
(()
for
_
in
range
(
nout
))
core_output_shapes
=
tuple
(()
for
_
in
range
(
nout
))
...
...
pytensor/link/numba/dispatch/scalar.py
浏览文件 @
10f285a1
...
@@ -85,9 +85,7 @@ def numba_funcify_ScalarOp(op, node, **kwargs):
...
@@ -85,9 +85,7 @@ def numba_funcify_ScalarOp(op, node, **kwargs):
unique_names
=
unique_name_generator
(
unique_names
=
unique_name_generator
(
[
scalar_op_fn_name
,
"scalar_func_numba"
],
suffix_sep
=
"_"
[
scalar_op_fn_name
,
"scalar_func_numba"
],
suffix_sep
=
"_"
)
)
input_names
=
", "
.
join
(
input_names
=
", "
.
join
(
unique_names
(
v
,
force_unique
=
True
)
for
v
in
node
.
inputs
)
[
unique_names
(
v
,
force_unique
=
True
)
for
v
in
node
.
inputs
]
)
if
not
has_pyx_skip_dispatch
:
if
not
has_pyx_skip_dispatch
:
scalar_op_src
=
f
"""
scalar_op_src
=
f
"""
def {scalar_op_fn_name}({input_names}):
def {scalar_op_fn_name}({input_names}):
...
@@ -115,10 +113,8 @@ def {scalar_op_fn_name}({input_names}):
...
@@ -115,10 +113,8 @@ def {scalar_op_fn_name}({input_names}):
input_names
=
[
unique_names
(
v
,
force_unique
=
True
)
for
v
in
node
.
inputs
]
input_names
=
[
unique_names
(
v
,
force_unique
=
True
)
for
v
in
node
.
inputs
]
converted_call_args
=
", "
.
join
(
converted_call_args
=
", "
.
join
(
[
f
"direct_cast({i_name}, {i_tmp_dtype_name})"
f
"direct_cast({i_name}, {i_tmp_dtype_name})"
for
i_name
,
i_tmp_dtype_name
in
zip
(
input_names
,
input_tmp_dtype_names
)
for
i_name
,
i_tmp_dtype_name
in
zip
(
input_names
,
input_tmp_dtype_names
)
]
)
)
if
not
has_pyx_skip_dispatch
:
if
not
has_pyx_skip_dispatch
:
scalar_op_src
=
f
"""
scalar_op_src
=
f
"""
...
...
pytensor/link/numba/dispatch/scan.py
浏览文件 @
10f285a1
...
@@ -373,7 +373,7 @@ def numba_funcify_Scan(op, node, **kwargs):
...
@@ -373,7 +373,7 @@ def numba_funcify_Scan(op, node, **kwargs):
inner_out_post_processing_block
=
"
\n
"
.
join
(
inner_out_post_processing_stmts
)
inner_out_post_processing_block
=
"
\n
"
.
join
(
inner_out_post_processing_stmts
)
inner_out_to_outer_out_stmts
=
"
\n
"
.
join
(
inner_out_to_outer_out_stmts
=
"
\n
"
.
join
(
[
f
"{s} = {d}"
for
s
,
d
in
zip
(
inner_out_to_outer_in_stmts
,
inner_output_names
)]
f
"{s} = {d}"
for
s
,
d
in
zip
(
inner_out_to_outer_in_stmts
,
inner_output_names
)
)
)
scan_op_src
=
f
"""
scan_op_src
=
f
"""
...
...
pytensor/link/numba/dispatch/tensor_basic.py
浏览文件 @
10f285a1
...
@@ -35,10 +35,8 @@ def numba_funcify_AllocEmpty(op, node, **kwargs):
...
@@ -35,10 +35,8 @@ def numba_funcify_AllocEmpty(op, node, **kwargs):
shape_var_item_names
=
[
f
"{name}_item"
for
name
in
shape_var_names
]
shape_var_item_names
=
[
f
"{name}_item"
for
name
in
shape_var_names
]
shapes_to_items_src
=
indent
(
shapes_to_items_src
=
indent
(
"
\n
"
.
join
(
"
\n
"
.
join
(
[
f
"{item_name} = to_scalar({shape_name})"
f
"{item_name} = to_scalar({shape_name})"
for
item_name
,
shape_name
in
zip
(
shape_var_item_names
,
shape_var_names
)
for
item_name
,
shape_name
in
zip
(
shape_var_item_names
,
shape_var_names
)
]
),
),
" "
*
4
,
" "
*
4
,
)
)
...
@@ -69,10 +67,8 @@ def numba_funcify_Alloc(op, node, **kwargs):
...
@@ -69,10 +67,8 @@ def numba_funcify_Alloc(op, node, **kwargs):
shape_var_item_names
=
[
f
"{name}_item"
for
name
in
shape_var_names
]
shape_var_item_names
=
[
f
"{name}_item"
for
name
in
shape_var_names
]
shapes_to_items_src
=
indent
(
shapes_to_items_src
=
indent
(
"
\n
"
.
join
(
"
\n
"
.
join
(
[
f
"{item_name} = to_scalar({shape_name})"
f
"{item_name} = to_scalar({shape_name})"
for
item_name
,
shape_name
in
zip
(
shape_var_item_names
,
shape_var_names
)
for
item_name
,
shape_name
in
zip
(
shape_var_item_names
,
shape_var_names
)
]
),
),
" "
*
4
,
" "
*
4
,
)
)
...
...
pytensor/link/numba/dispatch/vectorize_codegen.py
浏览文件 @
10f285a1
...
@@ -43,10 +43,8 @@ def store_core_outputs(core_op_fn: Callable, nin: int, nout: int) -> Callable:
...
@@ -43,10 +43,8 @@ def store_core_outputs(core_op_fn: Callable, nin: int, nout: int) -> Callable:
out_signature
=
", "
.
join
(
outputs
)
out_signature
=
", "
.
join
(
outputs
)
inner_out_signature
=
", "
.
join
(
inner_outputs
)
inner_out_signature
=
", "
.
join
(
inner_outputs
)
store_outputs
=
"
\n
"
.
join
(
store_outputs
=
"
\n
"
.
join
(
[
f
"{output}[...] = {inner_output}"
f
"{output}[...] = {inner_output}"
for
output
,
inner_output
in
zip
(
outputs
,
inner_outputs
)
for
output
,
inner_output
in
zip
(
outputs
,
inner_outputs
)
]
)
)
func_src
=
f
"""
func_src
=
f
"""
def store_core_outputs({inp_signature}, {out_signature}):
def store_core_outputs({inp_signature}, {out_signature}):
...
...
pytensor/link/vm.py
浏览文件 @
10f285a1
...
@@ -1112,7 +1112,7 @@ class VMLinker(LocalLinker):
...
@@ -1112,7 +1112,7 @@ class VMLinker(LocalLinker):
for
i
,
node
in
enumerate
(
nodes
):
for
i
,
node
in
enumerate
(
nodes
):
prereq_var_idxs
=
[]
prereq_var_idxs
=
[]
for
prereq_node
in
ords
.
get
(
node
,
[]):
for
prereq_node
in
ords
.
get
(
node
,
[]):
prereq_var_idxs
.
extend
(
[
vars_idx
[
v
]
for
v
in
prereq_node
.
outputs
]
)
prereq_var_idxs
.
extend
(
vars_idx
[
v
]
for
v
in
prereq_node
.
outputs
)
prereq_var_idxs
=
list
(
set
(
prereq_var_idxs
))
prereq_var_idxs
=
list
(
set
(
prereq_var_idxs
))
prereq_var_idxs
.
sort
()
# TODO: why sort?
prereq_var_idxs
.
sort
()
# TODO: why sort?
node_prereqs
.
append
(
prereq_var_idxs
)
node_prereqs
.
append
(
prereq_var_idxs
)
...
@@ -1323,9 +1323,7 @@ class VMLinker(LocalLinker):
...
@@ -1323,9 +1323,7 @@ class VMLinker(LocalLinker):
def
__repr__
(
self
):
def
__repr__
(
self
):
args_str
=
", "
.
join
(
args_str
=
", "
.
join
(
[
f
"{name}={getattr(self, name)}"
f
"{name}={getattr(self, name)}"
for
name
in
(
"use_cloop"
,
"lazy"
,
"allow_partial_eval"
,
"allow_gc"
)
for
name
in
(
"use_cloop"
,
"lazy"
,
"allow_partial_eval"
,
"allow_gc"
)
]
)
)
return
f
"{type(self).__name__}({args_str})"
return
f
"{type(self).__name__}({args_str})"
pytensor/misc/check_duplicate_key.py
浏览文件 @
10f285a1
...
@@ -9,10 +9,10 @@ from pytensor.configdefaults import config
...
@@ -9,10 +9,10 @@ from pytensor.configdefaults import config
DISPLAY_DUPLICATE_KEYS
=
False
DISPLAY_DUPLICATE_KEYS
=
False
DISPLAY_MOST_FREQUENT_DUPLICATE_CCODE
=
False
DISPLAY_MOST_FREQUENT_DUPLICATE_CCODE
=
False
dirs
=
[]
dirs
:
list
=
[]
if
len
(
sys
.
argv
)
>
1
:
if
len
(
sys
.
argv
)
>
1
:
for
compiledir
in
sys
.
argv
[
1
:]:
for
compiledir
in
sys
.
argv
[
1
:]:
dirs
.
extend
(
[
os
.
path
.
join
(
compiledir
,
d
)
for
d
in
os
.
listdir
(
compiledir
)]
)
dirs
.
extend
(
os
.
path
.
join
(
compiledir
,
d
)
for
d
in
os
.
listdir
(
compiledir
)
)
else
:
else
:
dirs
=
os
.
listdir
(
config
.
compiledir
)
dirs
=
os
.
listdir
(
config
.
compiledir
)
dirs
=
[
os
.
path
.
join
(
config
.
compiledir
,
d
)
for
d
in
dirs
]
dirs
=
[
os
.
path
.
join
(
config
.
compiledir
,
d
)
for
d
in
dirs
]
...
...
pytensor/printing.py
浏览文件 @
10f285a1
...
@@ -229,32 +229,32 @@ def debugprint(
...
@@ -229,32 +229,32 @@ def debugprint(
topo_orders
.
append
(
None
)
topo_orders
.
append
(
None
)
elif
isinstance
(
obj
,
Apply
):
elif
isinstance
(
obj
,
Apply
):
outputs_to_print
.
extend
(
obj
.
outputs
)
outputs_to_print
.
extend
(
obj
.
outputs
)
profile_list
.
extend
(
[
None
for
item
in
obj
.
outputs
]
)
profile_list
.
extend
(
None
for
item
in
obj
.
outputs
)
storage_maps
.
extend
(
[
None
for
item
in
obj
.
outputs
]
)
storage_maps
.
extend
(
None
for
item
in
obj
.
outputs
)
topo_orders
.
extend
(
[
None
for
item
in
obj
.
outputs
]
)
topo_orders
.
extend
(
None
for
item
in
obj
.
outputs
)
elif
isinstance
(
obj
,
Function
):
elif
isinstance
(
obj
,
Function
):
if
print_fgraph_inputs
:
if
print_fgraph_inputs
:
inputs_to_print
.
extend
(
obj
.
maker
.
fgraph
.
inputs
)
inputs_to_print
.
extend
(
obj
.
maker
.
fgraph
.
inputs
)
outputs_to_print
.
extend
(
obj
.
maker
.
fgraph
.
outputs
)
outputs_to_print
.
extend
(
obj
.
maker
.
fgraph
.
outputs
)
profile_list
.
extend
(
[
obj
.
profile
for
item
in
obj
.
maker
.
fgraph
.
outputs
]
)
profile_list
.
extend
(
obj
.
profile
for
item
in
obj
.
maker
.
fgraph
.
outputs
)
if
print_storage
:
if
print_storage
:
storage_maps
.
extend
(
storage_maps
.
extend
(
[
obj
.
vm
.
storage_map
for
item
in
obj
.
maker
.
fgraph
.
outputs
]
obj
.
vm
.
storage_map
for
item
in
obj
.
maker
.
fgraph
.
outputs
)
)
else
:
else
:
storage_maps
.
extend
(
[
None
for
item
in
obj
.
maker
.
fgraph
.
outputs
]
)
storage_maps
.
extend
(
None
for
item
in
obj
.
maker
.
fgraph
.
outputs
)
topo
=
obj
.
maker
.
fgraph
.
toposort
()
topo
=
obj
.
maker
.
fgraph
.
toposort
()
topo_orders
.
extend
(
[
topo
for
item
in
obj
.
maker
.
fgraph
.
outputs
]
)
topo_orders
.
extend
(
topo
for
item
in
obj
.
maker
.
fgraph
.
outputs
)
elif
isinstance
(
obj
,
FunctionGraph
):
elif
isinstance
(
obj
,
FunctionGraph
):
if
print_fgraph_inputs
:
if
print_fgraph_inputs
:
inputs_to_print
.
extend
(
obj
.
inputs
)
inputs_to_print
.
extend
(
obj
.
inputs
)
outputs_to_print
.
extend
(
obj
.
outputs
)
outputs_to_print
.
extend
(
obj
.
outputs
)
profile_list
.
extend
(
[
getattr
(
obj
,
"profile"
,
None
)
for
item
in
obj
.
outputs
]
)
profile_list
.
extend
(
getattr
(
obj
,
"profile"
,
None
)
for
item
in
obj
.
outputs
)
storage_maps
.
extend
(
storage_maps
.
extend
(
[
getattr
(
obj
,
"storage_map"
,
None
)
for
item
in
obj
.
outputs
]
getattr
(
obj
,
"storage_map"
,
None
)
for
item
in
obj
.
outputs
)
)
topo
=
obj
.
toposort
()
topo
=
obj
.
toposort
()
topo_orders
.
extend
(
[
topo
for
item
in
obj
.
outputs
]
)
topo_orders
.
extend
(
topo
for
item
in
obj
.
outputs
)
elif
isinstance
(
obj
,
int
|
float
|
np
.
ndarray
):
elif
isinstance
(
obj
,
int
|
float
|
np
.
ndarray
):
print
(
obj
,
file
=
_file
)
print
(
obj
,
file
=
_file
)
elif
isinstance
(
obj
,
In
|
Out
):
elif
isinstance
(
obj
,
In
|
Out
):
...
@@ -980,10 +980,10 @@ class FunctionPrinter(Printer):
...
@@ -980,10 +980,10 @@ class FunctionPrinter(Printer):
name
=
self
.
names
[
idx
]
name
=
self
.
names
[
idx
]
with
set_precedence
(
pstate
):
with
set_precedence
(
pstate
):
inputs_str
=
", "
.
join
(
inputs_str
=
", "
.
join
(
[
pprinter
.
process
(
input
,
pstate
)
for
input
in
node
.
inputs
]
pprinter
.
process
(
input
,
pstate
)
for
input
in
node
.
inputs
)
)
keywords_str
=
", "
.
join
(
keywords_str
=
", "
.
join
(
[
f
"{kw}={getattr(node.op, kw)}"
for
kw
in
self
.
keywords
]
f
"{kw}={getattr(node.op, kw)}"
for
kw
in
self
.
keywords
)
)
if
keywords_str
and
inputs_str
:
if
keywords_str
and
inputs_str
:
...
@@ -1050,7 +1050,7 @@ class DefaultPrinter(Printer):
...
@@ -1050,7 +1050,7 @@ class DefaultPrinter(Printer):
with
set_precedence
(
pstate
):
with
set_precedence
(
pstate
):
r
=
"{}({})"
.
format
(
r
=
"{}({})"
.
format
(
str
(
node
.
op
),
str
(
node
.
op
),
", "
.
join
(
[
pprinter
.
process
(
input
,
pstate
)
for
input
in
node
.
inputs
]
),
", "
.
join
(
pprinter
.
process
(
input
,
pstate
)
for
input
in
node
.
inputs
),
)
)
pstate
.
memo
[
output
]
=
r
pstate
.
memo
[
output
]
=
r
...
...
pytensor/scalar/basic.py
浏览文件 @
10f285a1
...
@@ -4224,8 +4224,8 @@ class Composite(ScalarInnerGraphOp):
...
@@ -4224,8 +4224,8 @@ class Composite(ScalarInnerGraphOp):
inputs
,
outputs
=
res
[
0
],
res2
[
1
]
inputs
,
outputs
=
res
[
0
],
res2
[
1
]
self
.
inputs
,
self
.
outputs
=
self
.
_cleanup_graph
(
inputs
,
outputs
)
self
.
inputs
,
self
.
outputs
=
self
.
_cleanup_graph
(
inputs
,
outputs
)
self
.
inputs_type
=
tuple
(
[
input
.
type
for
input
in
self
.
inputs
]
)
self
.
inputs_type
=
tuple
(
input
.
type
for
input
in
self
.
inputs
)
self
.
outputs_type
=
tuple
(
[
output
.
type
for
output
in
self
.
outputs
]
)
self
.
outputs_type
=
tuple
(
output
.
type
for
output
in
self
.
outputs
)
self
.
nin
=
len
(
inputs
)
self
.
nin
=
len
(
inputs
)
self
.
nout
=
len
(
outputs
)
self
.
nout
=
len
(
outputs
)
super
()
.
__init__
()
super
()
.
__init__
()
...
@@ -4247,7 +4247,7 @@ class Composite(ScalarInnerGraphOp):
...
@@ -4247,7 +4247,7 @@ class Composite(ScalarInnerGraphOp):
if
len
(
self
.
fgraph
.
outputs
)
>
1
or
len
(
self
.
fgraph
.
apply_nodes
)
>
10
:
if
len
(
self
.
fgraph
.
outputs
)
>
1
or
len
(
self
.
fgraph
.
apply_nodes
)
>
10
:
self
.
_name
=
"Composite{...}"
self
.
_name
=
"Composite{...}"
else
:
else
:
outputs_str
=
", "
.
join
(
[
pprint
(
output
)
for
output
in
self
.
fgraph
.
outputs
]
)
outputs_str
=
", "
.
join
(
pprint
(
output
)
for
output
in
self
.
fgraph
.
outputs
)
self
.
_name
=
f
"Composite{{{outputs_str}}}"
self
.
_name
=
f
"Composite{{{outputs_str}}}"
return
self
.
_name
return
self
.
_name
...
@@ -4295,7 +4295,7 @@ class Composite(ScalarInnerGraphOp):
...
@@ -4295,7 +4295,7 @@ class Composite(ScalarInnerGraphOp):
return
self
.
outputs_type
return
self
.
outputs_type
def
make_node
(
self
,
*
inputs
):
def
make_node
(
self
,
*
inputs
):
if
tuple
(
[
i
.
type
for
i
in
self
.
inputs
])
==
tuple
([
i
.
type
for
i
in
inputs
]
):
if
tuple
(
i
.
type
for
i
in
self
.
inputs
)
==
tuple
(
i
.
type
for
i
in
inputs
):
return
super
()
.
make_node
(
*
inputs
)
return
super
()
.
make_node
(
*
inputs
)
else
:
else
:
# Make a new op with the right input type.
# Make a new op with the right input type.
...
...
pytensor/scalar/loop.py
浏览文件 @
10f285a1
...
@@ -160,7 +160,7 @@ class ScalarLoop(ScalarInnerGraphOp):
...
@@ -160,7 +160,7 @@ class ScalarLoop(ScalarInnerGraphOp):
f
"Got {n_steps.type.dtype}"
,
f
"Got {n_steps.type.dtype}"
,
)
)
if
self
.
inputs_type
==
tuple
(
[
i
.
type
for
i
in
inputs
]
):
if
self
.
inputs_type
==
tuple
(
i
.
type
for
i
in
inputs
):
return
super
()
.
make_node
(
n_steps
,
*
inputs
)
return
super
()
.
make_node
(
n_steps
,
*
inputs
)
else
:
else
:
# Make a new op with the right input types.
# Make a new op with the right input types.
...
...
pytensor/scan/rewriting.py
浏览文件 @
10f285a1
...
@@ -1936,7 +1936,7 @@ class ScanMerge(GraphRewriter):
...
@@ -1936,7 +1936,7 @@ class ScanMerge(GraphRewriter):
profile
=
old_op
.
profile
,
profile
=
old_op
.
profile
,
truncate_gradient
=
old_op
.
truncate_gradient
,
truncate_gradient
=
old_op
.
truncate_gradient
,
allow_gc
=
old_op
.
allow_gc
,
allow_gc
=
old_op
.
allow_gc
,
name
=
"&"
.
join
(
[
nd
.
op
.
name
for
nd
in
nodes
]
),
name
=
"&"
.
join
(
nd
.
op
.
name
for
nd
in
nodes
),
)
)
new_outs
=
new_op
(
*
outer_ins
)
new_outs
=
new_op
(
*
outer_ins
)
...
...
pytensor/scan/utils.py
浏览文件 @
10f285a1
...
@@ -749,15 +749,13 @@ class ScanArgs:
...
@@ -749,15 +749,13 @@ class ScanArgs:
def
field_names
(
self
):
def
field_names
(
self
):
res
=
[
"mit_mot_out_slices"
,
"mit_mot_in_slices"
,
"mit_sot_in_slices"
]
res
=
[
"mit_mot_out_slices"
,
"mit_mot_in_slices"
,
"mit_sot_in_slices"
]
res
.
extend
(
res
.
extend
(
[
attr
attr
for
attr
in
self
.
__dict__
for
attr
in
self
.
__dict__
if
attr
.
startswith
(
"inner_in"
)
if
attr
.
startswith
(
"inner_in"
)
or
attr
.
startswith
(
"inner_out"
)
or
attr
.
startswith
(
"inner_out"
)
or
attr
.
startswith
(
"outer_in"
)
or
attr
.
startswith
(
"outer_in"
)
or
attr
.
startswith
(
"outer_out"
)
or
attr
.
startswith
(
"outer_out"
)
or
attr
==
"n_steps"
or
attr
==
"n_steps"
]
)
)
return
res
return
res
...
...
pytensor/tensor/basic.py
浏览文件 @
10f285a1
...
@@ -1554,7 +1554,7 @@ class Alloc(COp):
...
@@ -1554,7 +1554,7 @@ class Alloc(COp):
def
perform
(
self
,
node
,
inputs
,
out_
):
def
perform
(
self
,
node
,
inputs
,
out_
):
(
out
,)
=
out_
(
out
,)
=
out_
v
=
inputs
[
0
]
v
=
inputs
[
0
]
sh
=
tuple
(
[
int
(
i
)
for
i
in
inputs
[
1
:]
])
sh
=
tuple
(
int
(
i
)
for
i
in
inputs
[
1
:
])
self
.
_check_runtime_broadcast
(
node
,
v
,
sh
)
self
.
_check_runtime_broadcast
(
node
,
v
,
sh
)
if
out
[
0
]
is
None
or
out
[
0
]
.
shape
!=
sh
:
if
out
[
0
]
is
None
or
out
[
0
]
.
shape
!=
sh
:
...
@@ -4180,7 +4180,7 @@ class AllocEmpty(COp):
...
@@ -4180,7 +4180,7 @@ class AllocEmpty(COp):
def
perform
(
self
,
node
,
inputs
,
out_
):
def
perform
(
self
,
node
,
inputs
,
out_
):
(
out
,)
=
out_
(
out
,)
=
out_
sh
=
tuple
(
[
int
(
i
)
for
i
in
inputs
]
)
sh
=
tuple
(
int
(
i
)
for
i
in
inputs
)
if
out
[
0
]
is
None
or
out
[
0
]
.
shape
!=
sh
:
if
out
[
0
]
is
None
or
out
[
0
]
.
shape
!=
sh
:
out
[
0
]
=
np
.
empty
(
sh
,
dtype
=
self
.
dtype
)
out
[
0
]
=
np
.
empty
(
sh
,
dtype
=
self
.
dtype
)
...
...
pytensor/tensor/blas.py
浏览文件 @
10f285a1
...
@@ -1691,7 +1691,7 @@ class BatchedDot(COp):
...
@@ -1691,7 +1691,7 @@ class BatchedDot(COp):
if
x
.
shape
[
0
]
!=
y
.
shape
[
0
]:
if
x
.
shape
[
0
]
!=
y
.
shape
[
0
]:
raise
TypeError
(
raise
TypeError
(
f
"Inputs [{', '.join(map(str, inp))}] must have the"
f
"Inputs [{', '.join(map(str, inp))}] must have the"
f
" same size in axis 0, but have sizes [{', '.join(
[str(i.shape[0]) for i in inp]
)}]."
f
" same size in axis 0, but have sizes [{', '.join(
str(i.shape[0]) for i in inp
)}]."
)
)
z
[
0
]
=
np
.
matmul
(
x
,
y
)
z
[
0
]
=
np
.
matmul
(
x
,
y
)
...
...
pytensor/tensor/blockwise.py
浏览文件 @
10f285a1
...
@@ -139,10 +139,8 @@ class Blockwise(Op):
...
@@ -139,10 +139,8 @@ class Blockwise(Op):
try
:
try
:
batch_shape
=
tuple
(
batch_shape
=
tuple
(
[
broadcast_static_dim_lengths
(
batch_dims
)
broadcast_static_dim_lengths
(
batch_dims
)
for
batch_dims
in
zip
(
*
batch_shapes
)
for
batch_dims
in
zip
(
*
batch_shapes
)
]
)
)
except
ValueError
:
except
ValueError
:
raise
ValueError
(
raise
ValueError
(
...
...
pytensor/tensor/elemwise.py
浏览文件 @
10f285a1
...
@@ -182,7 +182,7 @@ class DimShuffle(ExternalCOp):
...
@@ -182,7 +182,7 @@ class DimShuffle(ExternalCOp):
self
.
transposition
=
self
.
shuffle
+
drop
self
.
transposition
=
self
.
shuffle
+
drop
# List of dimensions of the output that are broadcastable and were not
# List of dimensions of the output that are broadcastable and were not
# in the original input
# in the original input
self
.
augment
=
sorted
(
[
i
for
i
,
x
in
enumerate
(
new_order
)
if
x
==
"x"
]
)
self
.
augment
=
sorted
(
i
for
i
,
x
in
enumerate
(
new_order
)
if
x
==
"x"
)
self
.
drop
=
drop
self
.
drop
=
drop
if
self
.
inplace
:
if
self
.
inplace
:
...
@@ -893,11 +893,9 @@ class Elemwise(OpenMPOp):
...
@@ -893,11 +893,9 @@ class Elemwise(OpenMPOp):
# In that case, create a fortran output ndarray.
# In that case, create a fortran output ndarray.
z
=
list
(
zip
(
inames
,
inputs
))
z
=
list
(
zip
(
inames
,
inputs
))
alloc_fortran
=
" && "
.
join
(
alloc_fortran
=
" && "
.
join
(
[
f
"PyArray_ISFORTRAN({arr})"
f
"PyArray_ISFORTRAN({arr})"
for
arr
,
var
in
z
for
arr
,
var
in
z
if
not
all
(
s
==
1
for
s
in
var
.
type
.
shape
)
if
not
all
(
s
==
1
for
s
in
var
.
type
.
shape
)
]
)
)
# If it is a scalar, make it c contig to prevent problem with
# If it is a scalar, make it c contig to prevent problem with
# NumPy C and F contig not always set as both of them.
# NumPy C and F contig not always set as both of them.
...
@@ -984,12 +982,10 @@ class Elemwise(OpenMPOp):
...
@@ -984,12 +982,10 @@ class Elemwise(OpenMPOp):
if
len
(
all_code
)
==
1
:
if
len
(
all_code
)
==
1
:
# No loops
# No loops
task_decl
=
""
.
join
(
task_decl
=
""
.
join
(
[
f
"{dtype}& {name}_i = *{name}_iter;
\n
"
f
"{dtype}& {name}_i = *{name}_iter;
\n
"
for
name
,
dtype
in
zip
(
for
name
,
dtype
in
zip
(
inames
+
list
(
real_onames
),
idtypes
+
list
(
real_odtypes
)
inames
+
list
(
real_onames
),
idtypes
+
list
(
real_odtypes
)
)
)
]
)
)
preloops
=
{}
preloops
=
{}
...
@@ -1101,18 +1097,14 @@ class Elemwise(OpenMPOp):
...
@@ -1101,18 +1097,14 @@ class Elemwise(OpenMPOp):
z
=
list
(
zip
(
inames
+
onames
,
inputs
+
node
.
outputs
))
z
=
list
(
zip
(
inames
+
onames
,
inputs
+
node
.
outputs
))
all_broadcastable
=
all
(
s
==
1
for
s
in
var
.
type
.
shape
)
all_broadcastable
=
all
(
s
==
1
for
s
in
var
.
type
.
shape
)
cond1
=
" && "
.
join
(
cond1
=
" && "
.
join
(
[
f
"PyArray_ISCONTIGUOUS({arr})"
f
"PyArray_ISCONTIGUOUS({arr})"
for
arr
,
var
in
z
for
arr
,
var
in
z
if
not
all_broadcastable
if
not
all_broadcastable
]
)
)
cond2
=
" && "
.
join
(
cond2
=
" && "
.
join
(
[
f
"PyArray_ISFORTRAN({arr})"
f
"PyArray_ISFORTRAN({arr})"
for
arr
,
var
in
z
for
arr
,
var
in
z
if
not
all_broadcastable
if
not
all_broadcastable
]
)
)
loop
=
"""
loop
=
"""
if(({cond1}) || ({cond2})){{
if(({cond1}) || ({cond2})){{
...
...
pytensor/tensor/extra_ops.py
浏览文件 @
10f285a1
...
@@ -1248,7 +1248,7 @@ class Unique(Op):
...
@@ -1248,7 +1248,7 @@ class Unique(Op):
f
"Unique axis `{self.axis}` is outside of input ndim = {ndim}."
f
"Unique axis `{self.axis}` is outside of input ndim = {ndim}."
)
)
ret
[
0
]
=
tuple
(
ret
[
0
]
=
tuple
(
[
fgraph
.
shape_feature
.
shape_ir
(
i
,
node
.
outputs
[
0
])
for
i
in
range
(
ndim
)]
fgraph
.
shape_feature
.
shape_ir
(
i
,
node
.
outputs
[
0
])
for
i
in
range
(
ndim
)
)
)
if
self
.
return_inverse
:
if
self
.
return_inverse
:
if
self
.
axis
is
None
:
if
self
.
axis
is
None
:
...
...
pytensor/tensor/math.py
浏览文件 @
10f285a1
...
@@ -258,11 +258,9 @@ class Argmax(COp):
...
@@ -258,11 +258,9 @@ class Argmax(COp):
if
self
.
axis
is
None
:
if
self
.
axis
is
None
:
return
[()]
return
[()]
rval
=
tuple
(
rval
=
tuple
(
[
ishape
[
i
]
ishape
[
i
]
for
(
i
,
b
)
in
enumerate
(
node
.
inputs
[
0
]
.
type
.
broadcastable
)
for
(
i
,
b
)
in
enumerate
(
node
.
inputs
[
0
]
.
type
.
broadcastable
)
if
i
not
in
self
.
axis
if
i
not
in
self
.
axis
]
)
)
return
[
rval
]
return
[
rval
]
...
...
pytensor/tensor/shape.py
浏览文件 @
10f285a1
...
@@ -800,10 +800,8 @@ class Reshape(COp):
...
@@ -800,10 +800,8 @@ class Reshape(COp):
rest_size
=
input_size
//
maximum
(
requ_size
,
1
)
rest_size
=
input_size
//
maximum
(
requ_size
,
1
)
return
[
return
[
tuple
(
tuple
(
[
ptb
.
switch
(
eq
(
requ
[
i
],
-
1
),
rest_size
,
requ
[
i
])
ptb
.
switch
(
eq
(
requ
[
i
],
-
1
),
rest_size
,
requ
[
i
])
for
i
in
range
(
self
.
ndim
)
for
i
in
range
(
self
.
ndim
)
]
)
)
]
]
...
...
pytensor/tensor/slinalg.py
浏览文件 @
10f285a1
...
@@ -879,7 +879,7 @@ class BaseBlockDiagonal(Op):
...
@@ -879,7 +879,7 @@ class BaseBlockDiagonal(Op):
__props__
=
(
"n_inputs"
,)
__props__
=
(
"n_inputs"
,)
def
__init__
(
self
,
n_inputs
):
def
__init__
(
self
,
n_inputs
):
input_sig
=
","
.
join
(
[
f
"(m{i},n{i})"
for
i
in
range
(
n_inputs
)]
)
input_sig
=
","
.
join
(
f
"(m{i},n{i})"
for
i
in
range
(
n_inputs
)
)
self
.
gufunc_signature
=
f
"{input_sig}->(m,n)"
self
.
gufunc_signature
=
f
"{input_sig}->(m,n)"
if
n_inputs
==
0
:
if
n_inputs
==
0
:
...
...
pytensor/tensor/subtensor.py
浏览文件 @
10f285a1
...
@@ -1113,7 +1113,7 @@ class Subtensor(COp):
...
@@ -1113,7 +1113,7 @@ class Subtensor(COp):
if
is_slice
:
if
is_slice
:
is_slice_init
=
(
is_slice_init
=
(
"int is_slice[] = {"
+
","
.
join
(
[
str
(
s
)
for
s
in
is_slice
]
)
+
"};"
"int is_slice[] = {"
+
","
.
join
(
str
(
s
)
for
s
in
is_slice
)
+
"};"
)
)
else
:
else
:
is_slice_init
=
"int* is_slice = NULL;"
is_slice_init
=
"int* is_slice = NULL;"
...
@@ -2401,9 +2401,7 @@ class AdvancedIncSubtensor1(COp):
...
@@ -2401,9 +2401,7 @@ class AdvancedIncSubtensor1(COp):
fn_array
=
(
fn_array
=
(
"static inplace_map_binop addition_funcs[] = {"
"static inplace_map_binop addition_funcs[] = {"
+
""
.
join
(
+
""
.
join
(
gen_binop
(
type
=
t
,
typen
=
t
.
upper
())
for
t
in
types
+
complex_types
)
[
gen_binop
(
type
=
t
,
typen
=
t
.
upper
())
for
t
in
types
+
complex_types
]
)
+
"NULL};
\n
"
+
"NULL};
\n
"
)
)
...
@@ -2416,7 +2414,7 @@ class AdvancedIncSubtensor1(COp):
...
@@ -2416,7 +2414,7 @@ class AdvancedIncSubtensor1(COp):
type_number_array
=
(
type_number_array
=
(
"static int type_numbers[] = {"
"static int type_numbers[] = {"
+
""
.
join
(
[
gen_num
(
typen
=
t
.
upper
())
for
t
in
types
+
complex_types
]
)
+
""
.
join
(
gen_num
(
typen
=
t
.
upper
())
for
t
in
types
+
complex_types
)
+
"-1000};"
+
"-1000};"
)
)
...
...
pytensor/tensor/type.py
浏览文件 @
10f285a1
...
@@ -401,7 +401,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
...
@@ -401,7 +401,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
else
:
else
:
return
str
(
s
)
return
str
(
s
)
formatted_shape
=
", "
.
join
(
[
shape_str
(
s
)
for
s
in
shape
]
)
formatted_shape
=
", "
.
join
(
shape_str
(
s
)
for
s
in
shape
)
if
len_shape
==
1
:
if
len_shape
==
1
:
formatted_shape
+=
","
formatted_shape
+=
","
...
...
pytensor/tensor/variable.py
浏览文件 @
10f285a1
...
@@ -521,12 +521,10 @@ class _tensor_py_operators:
...
@@ -521,12 +521,10 @@ class _tensor_py_operators:
# Else leave it as is if it is a real number
# Else leave it as is if it is a real number
# Convert python literals to pytensor constants
# Convert python literals to pytensor constants
args
=
tuple
(
args
=
tuple
(
[
pt
.
subtensor
.
as_index_constant
(
pt
.
subtensor
.
as_index_constant
(
np
.
array
(
inp
,
dtype
=
np
.
uint8
)
if
is_empty_array
(
inp
)
else
inp
np
.
array
(
inp
,
dtype
=
np
.
uint8
)
if
is_empty_array
(
inp
)
else
inp
)
)
for
inp
in
args
for
inp
in
args
]
)
)
# Determine if advanced indexing is needed or not. The logic is
# Determine if advanced indexing is needed or not. The logic is
...
...
tests/tensor/test_math.py
浏览文件 @
10f285a1
...
@@ -3418,7 +3418,7 @@ class TestSumMeanMaxMinArgMaxVarReduceAxes:
...
@@ -3418,7 +3418,7 @@ class TestSumMeanMaxMinArgMaxVarReduceAxes:
def
reduce_bitwise_and
(
x
,
axis
=-
1
,
dtype
=
"int8"
):
def
reduce_bitwise_and
(
x
,
axis
=-
1
,
dtype
=
"int8"
):
identity
=
np
.
array
((
-
1
,),
dtype
=
dtype
)[
0
]
identity
=
np
.
array
((
-
1
,),
dtype
=
dtype
)[
0
]
shape_without_axis
=
tuple
(
[
s
for
i
,
s
in
enumerate
(
x
.
shape
)
if
i
!=
axis
]
)
shape_without_axis
=
tuple
(
s
for
i
,
s
in
enumerate
(
x
.
shape
)
if
i
!=
axis
)
if
0
in
shape_without_axis
:
if
0
in
shape_without_axis
:
return
np
.
empty
(
shape
=
shape_without_axis
,
dtype
=
x
.
dtype
)
return
np
.
empty
(
shape
=
shape_without_axis
,
dtype
=
x
.
dtype
)
...
...
tests/tensor/utils.py
浏览文件 @
10f285a1
...
@@ -667,7 +667,7 @@ def makeBroadcastTester(op, expected, checks=None, name=None, **kwargs):
...
@@ -667,7 +667,7 @@ def makeBroadcastTester(op, expected, checks=None, name=None, **kwargs):
# For instance: sub_inplace -> SubInplace
# For instance: sub_inplace -> SubInplace
capitalize
=
True
capitalize
=
True
if
capitalize
:
if
capitalize
:
name
=
""
.
join
(
[
x
.
capitalize
()
for
x
in
name
.
split
(
"_"
)]
)
name
=
""
.
join
(
x
.
capitalize
()
for
x
in
name
.
split
(
"_"
)
)
# Some tests specify a name that already ends with 'Tester', while in other
# Some tests specify a name that already ends with 'Tester', while in other
# cases we need to add it manually.
# cases we need to add it manually.
if
not
name
.
endswith
(
"Tester"
):
if
not
name
.
endswith
(
"Tester"
):
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论