Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
8a6d2aae
提交
8a6d2aae
authored
6月 21, 2024
作者:
Virgile Andreani
提交者:
Ricardo Vieira
7月 03, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Rewrite for/append as list comprehensions
上级
bf73f8a0
隐藏空白字符变更
内嵌
并排
正在显示
21 个修改的文件
包含
113 行增加
和
164 行删除
+113
-164
debugmode.py
pytensor/compile/debugmode.py
+4
-5
__init__.py
pytensor/compile/function/__init__.py
+3
-11
types.py
pytensor/compile/function/types.py
+2
-6
profiling.py
pytensor/compile/profiling.py
+7
-7
basic.py
pytensor/graph/basic.py
+1
-2
fg.py
pytensor/graph/fg.py
+4
-2
basic.py
pytensor/graph/rewriting/basic.py
+4
-4
basic.py
pytensor/link/basic.py
+17
-18
basic.py
pytensor/link/c/basic.py
+15
-17
op.py
pytensor/link/c/op.py
+13
-21
basic.py
pytensor/link/numba/dispatch/basic.py
+10
-12
elemwise.py
pytensor/link/numba/dispatch/elemwise.py
+1
-3
scan.py
pytensor/link/numba/dispatch/scan.py
+8
-8
vm.py
pytensor/link/vm.py
+1
-2
basic.py
pytensor/scalar/basic.py
+1
-3
basic.py
pytensor/scan/basic.py
+1
-3
basic.py
pytensor/tensor/basic.py
+7
-16
blockwise.py
pytensor/tensor/blockwise.py
+1
-2
elemwise.py
pytensor/tensor/elemwise.py
+8
-4
subtensor.py
pytensor/tensor/subtensor.py
+4
-16
basic.py
pytensor/typed_list/basic.py
+1
-2
没有找到文件。
pytensor/compile/debugmode.py
浏览文件 @
8a6d2aae
...
...
@@ -906,11 +906,10 @@ def _get_preallocated_maps(
name
=
f
"strided{tuple(steps)}"
for
r
in
considered_outputs
:
if
r
in
init_strided
:
strides
=
[]
shapes
=
[]
for
i
,
size
in
enumerate
(
r_vals
[
r
]
.
shape
):
shapes
.
append
(
slice
(
None
,
size
,
None
))
strides
.
append
(
slice
(
None
,
None
,
steps
[
i
]))
shapes
=
[
slice
(
None
,
size
,
None
)
for
size
in
r_vals
[
r
]
.
shape
]
strides
=
[
slice
(
None
,
None
,
steps
[
i
])
for
i
in
range
(
r_vals
[
r
]
.
ndim
)
]
r_buf
=
init_strided
[
r
]
...
...
pytensor/compile/function/__init__.py
浏览文件 @
8a6d2aae
...
...
@@ -247,18 +247,10 @@ def function(
"""
if
isinstance
(
outputs
,
dict
):
output_items
=
list
(
outputs
.
items
()
)
assert
all
(
isinstance
(
k
,
str
)
for
k
in
outputs
)
for
item_pair
in
output_items
:
assert
isinstance
(
item_pair
[
0
],
str
)
output_items_sorted
=
sorted
(
output_items
)
output_keys
=
[]
outputs
=
[]
for
pair
in
output_items_sorted
:
output_keys
.
append
(
pair
[
0
])
outputs
.
append
(
pair
[
1
])
output_keys
=
sorted
(
outputs
)
outputs
=
[
outputs
[
key
]
for
key
in
output_keys
]
else
:
output_keys
=
None
...
...
pytensor/compile/function/types.py
浏览文件 @
8a6d2aae
...
...
@@ -212,18 +212,14 @@ def std_fgraph(
found_updates
.
extend
(
map
(
SymbolicOutput
,
updates
))
elif
fgraph
is
None
:
input_vars
=
[]
# If one of the inputs is non-atomic (i.e. has a non-`None` `Variable.owner`),
# then we need to create/clone the graph starting at these inputs.
# The result will be atomic versions of the given inputs connected to
# the same outputs.
# Otherwise, when all the inputs are already atomic, there's no need to
# clone the graph.
clone
=
force_clone
for
spec
in
input_specs
:
input_vars
.
append
(
spec
.
variable
)
clone
|=
spec
.
variable
.
owner
is
not
None
input_vars
=
[
spec
.
variable
for
spec
in
input_specs
]
clone
=
force_clone
or
any
(
var
.
owner
is
not
None
for
var
in
input_vars
)
fgraph
=
FunctionGraph
(
input_vars
,
...
...
pytensor/compile/profiling.py
浏览文件 @
8a6d2aae
...
...
@@ -1204,8 +1204,7 @@ class ProfileStats:
compute_map
[
var
][
0
]
=
0
for
k_remove
,
v_remove
in
viewedby_remove
.
items
():
for
i
in
v_remove
:
viewed_by
[
k_remove
]
.
append
(
i
)
viewed_by
[
k_remove
]
.
extend
(
v_remove
)
for
k_add
,
v_add
in
viewedby_add
.
items
():
for
i
in
v_add
:
...
...
@@ -1215,15 +1214,16 @@ class ProfileStats:
del
view_of
[
k
]
# two data structure used to mimic Python gc
viewed_by
=
{}
#
{var1: [vars that view var1]}
# *
{var1: [vars that view var1]}
# The len of the list is the value of python ref
# count. But we use a list, not just the ref count value.
# This is more safe to help detect potential bug
in the algo
for
var
in
fgraph
.
variables
:
viewed_by
[
var
]
=
[]
view_of
=
{}
#
{var1: original var viewed by var1}
# This is more safe to help detect potential bug in the algo
viewed_by
=
{
var
:
[]
for
var
in
fgraph
.
variables
}
# *
{var1: original var viewed by var1}
# The original mean that we don't keep track of all the intermediate
# relationship in the view.
view_of
=
{}
min_memory_generator
(
executable_nodes
,
viewed_by
,
view_of
)
...
...
pytensor/graph/basic.py
浏览文件 @
8a6d2aae
...
...
@@ -1474,9 +1474,8 @@ def general_toposort(
_clients
:
dict
[
T
,
list
[
T
]]
=
{}
sources
:
deque
[
T
]
=
deque
()
search_res_len
:
int
=
0
search_res_len
=
len
(
search_res
)
for
snode
,
children
in
search_res
:
search_res_len
+=
1
if
children
:
for
child
in
children
:
_clients
.
setdefault
(
child
,
[])
.
append
(
snode
)
...
...
pytensor/graph/fg.py
浏览文件 @
8a6d2aae
...
...
@@ -270,8 +270,10 @@ class FunctionGraph(MetaObject):
self
.
execute_callbacks
(
"on_prune"
,
apply_node
,
reason
)
for
i
,
in_var
in
enumerate
(
apply_node
.
inputs
):
removal_stack
.
append
((
in_var
,
(
apply_node
,
i
)))
removal_stack
.
extend
(
(
in_var
,
(
apply_node
,
i
))
for
i
,
in_var
in
enumerate
(
apply_node
.
inputs
)
)
if
remove_if_empty
:
del
clients
[
var
]
...
...
pytensor/graph/rewriting/basic.py
浏览文件 @
8a6d2aae
...
...
@@ -479,9 +479,9 @@ class SequentialGraphRewriter(GraphRewriter, UserList):
new_sub_profile
.
append
(
p
[
6
][
idx
])
new_rewrite
=
SequentialGraphRewriter
(
*
new_l
)
new_nb_nodes
=
[
]
for
p1
,
p2
in
zip
(
prof1
[
8
],
prof2
[
8
]):
new_nb_nodes
.
append
((
p1
[
0
]
+
p2
[
0
],
p1
[
1
]
+
p2
[
1
]))
new_nb_nodes
=
[
(
p1
[
0
]
+
p2
[
0
],
p1
[
1
]
+
p2
[
1
])
for
p1
,
p2
in
zip
(
prof1
[
8
],
prof2
[
8
])
]
new_nb_nodes
.
extend
(
prof1
[
8
][
len
(
new_nb_nodes
)
:])
new_nb_nodes
.
extend
(
prof2
[
8
][
len
(
new_nb_nodes
)
:])
...
...
@@ -960,9 +960,9 @@ class MetaNodeRewriter(NodeRewriter):
tracks
=
rewriter
.
tracks
()
if
tracks
:
self
.
_tracks
.
extend
(
tracks
)
for
c
in
tracks
:
self
.
track_dict
[
c
]
.
append
(
rewriter
)
self
.
_tracks
.
append
(
c
)
for
tag
in
tag_list
:
self
.
tag_dict
[
tag
]
.
append
(
rewriter
)
...
...
pytensor/link/basic.py
浏览文件 @
8a6d2aae
...
...
@@ -524,12 +524,13 @@ class WrapLinker(Linker):
thunk_groups
=
list
(
zip
(
*
thunk_lists
))
order
=
[
x
[
0
]
for
x
in
zip
(
*
order_lists
)]
to_reset
=
[]
for
thunks
,
node
in
zip
(
thunk_groups
,
order
):
for
j
,
output
in
enumerate
(
node
.
outputs
):
if
output
in
no_recycling
:
for
thunk
in
thunks
:
to_reset
.
append
(
thunk
.
outputs
[
j
])
to_reset
=
[
thunk
.
outputs
[
j
]
for
thunks
,
node
in
zip
(
thunk_groups
,
order
)
for
j
,
output
in
enumerate
(
node
.
outputs
)
if
output
in
no_recycling
for
thunk
in
thunks
]
wrapper
=
self
.
wrapper
pre
=
self
.
pre
...
...
@@ -696,18 +697,16 @@ class JITLinker(PerformLinker):
computed
,
last_user
=
gc_helper
(
nodes
)
if
self
.
allow_gc
:
post_thunk_old_storage
=
[]
for
node
in
nodes
:
post_thunk_old_storage
.
append
(
[
storage_map
[
input
]
for
input
in
node
.
inputs
if
(
input
in
computed
)
and
(
input
not
in
fgraph
.
outputs
)
and
(
node
==
last_user
[
input
])
]
)
post_thunk_old_storage
=
[
[
storage_map
[
input
]
for
input
in
node
.
inputs
if
(
input
in
computed
)
and
(
input
not
in
fgraph
.
outputs
)
and
(
node
==
last_user
[
input
])
]
for
node
in
nodes
]
else
:
post_thunk_old_storage
=
None
...
...
pytensor/link/c/basic.py
浏览文件 @
8a6d2aae
...
...
@@ -1129,19 +1129,18 @@ class CLinker(Linker):
)
def
get_init_tasks
(
self
):
init_tasks
=
[]
tasks
=
[]
vars
=
[
v
for
v
in
self
.
variables
if
v
not
in
self
.
consts
]
id
=
1
for
v
in
self
.
variables
:
if
v
in
self
.
consts
:
continue
init_tasks
.
append
((
v
,
"init"
,
id
)
)
tasks
.
append
((
v
,
"get"
,
id
+
1
))
id
+=
2
for
node
in
self
.
node_order
:
tasks
.
append
((
node
,
"code"
,
id
))
init_tasks
.
append
((
node
,
"init"
,
id
+
1
)
)
id
+=
2
init_tasks
=
[(
v
,
"init"
,
id
+
2
*
i
)
for
i
,
v
in
enumerate
(
vars
)]
tasks
=
[(
v
,
"get"
,
id
+
2
*
i
+
1
)
for
i
,
v
in
enumerate
(
vars
)]
id
+=
2
*
len
(
vars
)
tasks
.
extend
(
(
node
,
"code"
,
id
+
2
*
i
)
for
i
,
node
in
enumerate
(
self
.
node_order
)
)
init_tasks
.
extend
(
(
node
,
"init"
,
id
+
2
*
i
+
1
)
for
i
,
node
in
enumerate
(
self
.
node_order
)
)
return
init_tasks
,
tasks
def
make_thunk
(
...
...
@@ -1492,12 +1491,11 @@ class CLinker(Linker):
# graph's information used to compute the key. If we mistakenly
# pretend that inputs with clients don't have any, were are only using
# those inputs more than once to compute the key.
for
ipos
,
var
in
[
(
i
,
var
)
for
i
,
var
in
enumerate
(
fgraph
.
inputs
)
sig
.
extend
(
(
var
.
type
,
in_sig
(
var
,
-
1
,
ipos
)
)
for
i
pos
,
var
in
enumerate
(
fgraph
.
inputs
)
if
not
len
(
fgraph
.
clients
[
var
])
]:
sig
.
append
((
var
.
type
,
in_sig
(
var
,
-
1
,
ipos
)))
)
# crystalize the signature and version
sig
=
tuple
(
sig
)
...
...
pytensor/link/c/op.py
浏览文件 @
8a6d2aae
...
...
@@ -220,12 +220,7 @@ int main( int argc, const char* argv[] )
def
lquote_macro
(
txt
:
str
)
->
str
:
"""Turn the last line of text into a ``
\\
``-commented line."""
res
=
[]
spl
=
txt
.
split
(
"
\n
"
)
for
l
in
spl
[:
-
1
]:
res
.
append
(
l
+
"
\\
"
)
res
.
append
(
spl
[
-
1
])
return
"
\n
"
.
join
(
res
)
return
"
\\\n
"
.
join
(
txt
.
split
(
"
\n
"
))
def
get_sub_macros
(
sub
:
dict
[
str
,
str
])
->
tuple
[
str
,
str
]:
...
...
@@ -240,21 +235,17 @@ def get_sub_macros(sub: dict[str, str]) -> tuple[str, str]:
return
"
\n
"
.
join
(
define_macros
),
"
\n
"
.
join
(
undef_macros
)
def
get_io_macros
(
inputs
:
list
[
str
],
outputs
:
list
[
str
]
)
->
tuple
[
list
[
str
]]
|
tuple
[
str
,
str
]:
define_macros
=
[]
undef_macros
=
[]
def
get_io_macros
(
inputs
:
list
[
str
],
outputs
:
list
[
str
])
->
tuple
[
str
,
str
]:
define_inputs
=
[
f
"#define INPUT_{int(i)} {inp}"
for
i
,
inp
in
enumerate
(
inputs
)]
define_outputs
=
[
f
"#define OUTPUT_{int(i)} {out}"
for
i
,
out
in
enumerate
(
outputs
)]
for
i
,
inp
in
enumerate
(
inputs
):
define_macros
.
append
(
f
"#define INPUT_{int(i)} {inp}"
)
undef_macros
.
append
(
f
"#undef INPUT_{int(i)}"
)
undef_inputs
=
[
f
"#undef INPUT_{int(i)}"
for
i
in
range
(
len
(
inputs
))]
undef_outputs
=
[
f
"#undef OUTPUT_{int(i)}"
for
i
in
range
(
len
(
outputs
))]
for
i
,
out
in
enumerate
(
outputs
):
define_macros
.
append
(
f
"#define OUTPUT_{int(i)} {out}"
)
undef_macros
.
append
(
f
"#undef OUTPUT_{int(i)}"
)
define_all
=
"
\n
"
.
join
(
define_inputs
+
define_outputs
)
undef_all
=
"
\n
"
.
join
(
undef_inputs
+
undef_outputs
)
return
"
\n
"
.
join
(
define_macros
),
"
\n
"
.
join
(
undef_macros
)
return
define_all
,
undef_all
class
ExternalCOp
(
COp
):
...
...
@@ -560,9 +551,10 @@ class ExternalCOp(COp):
define_macros
.
append
(
define_template
%
(
"APPLY_SPECIFIC(str)"
,
f
"str##_{name}"
))
undef_macros
.
append
(
undef_template
%
"APPLY_SPECIFIC"
)
for
n
,
v
in
self
.
__get_op_params
():
define_macros
.
append
(
define_template
%
(
n
,
v
))
undef_macros
.
append
(
undef_template
%
(
n
,))
define_macros
.
extend
(
define_template
%
(
n
,
v
)
for
n
,
v
in
self
.
__get_op_params
()
)
undef_macros
.
extend
(
undef_template
%
(
n
,)
for
n
,
_
in
self
.
__get_op_params
())
return
"
\n
"
.
join
(
define_macros
),
"
\n
"
.
join
(
undef_macros
)
...
...
pytensor/link/numba/dispatch/basic.py
浏览文件 @
8a6d2aae
...
...
@@ -131,21 +131,19 @@ def create_numba_signature(
reduce_to_scalar
:
bool
=
False
,
)
->
numba
.
types
.
Type
:
"""Create a Numba type for the signature of an `Apply` node or `FunctionGraph`."""
input_types
=
[]
for
inp
in
node_or_fgraph
.
inputs
:
input_types
.
append
(
get_numba_type
(
inp
.
type
,
force_scalar
=
force_scalar
,
reduce_to_scalar
=
reduce_to_scalar
)
input_types
=
[
get_numba_type
(
inp
.
type
,
force_scalar
=
force_scalar
,
reduce_to_scalar
=
reduce_to_scalar
)
for
inp
in
node_or_fgraph
.
inputs
]
output_types
=
[]
for
out
in
node_or_fgraph
.
outputs
:
output_types
.
append
(
get_numba_type
(
out
.
type
,
force_scalar
=
force_scalar
,
reduce_to_scalar
=
reduce_to_scalar
)
output_types
=
[
get_numba_type
(
out
.
type
,
force_scalar
=
force_scalar
,
reduce_to_scalar
=
reduce_to_scalar
)
for
out
in
node_or_fgraph
.
outputs
]
if
len
(
output_types
)
>
1
:
return
numba
.
types
.
Tuple
(
output_types
)(
*
input_types
)
...
...
pytensor/link/numba/dispatch/elemwise.py
浏览文件 @
8a6d2aae
...
...
@@ -520,9 +520,7 @@ def numba_funcify_Elemwise(op, node, **kwargs):
if
length
==
1
and
shape
and
iter_length
!=
1
and
not
allow_bc
:
raise
ValueError
(
"Broadcast not allowed."
)
outputs
=
[]
for
dtype
in
output_dtypes
:
outputs
.
append
(
np
.
empty
(
shape
,
dtype
=
dtype
))
outputs
=
[
np
.
empty
(
shape
,
dtype
=
dtype
)
for
dtype
in
output_dtypes
]
for
idx
in
np
.
ndindex
(
shape
):
vals
=
[
input
[
idx
]
for
input
in
inputs_bc
]
...
...
pytensor/link/numba/dispatch/scan.py
浏览文件 @
8a6d2aae
...
...
@@ -268,15 +268,15 @@ def numba_funcify_Scan(op, node, **kwargs):
output_taps
=
inner_in_names_to_output_taps
.
get
(
outer_in_name
,
[
tap_storage_size
]
)
for
out_tap
in
output_taps
:
inner_out_to_outer_in_stmts
.
append
(
idx_to_str
(
storage_name
,
out_tap
,
size
=
storage_size_name
,
allow_scalar
=
True
,
)
inner_out_to_outer_in_stmts
.
extend
(
idx_to_str
(
storage_name
,
out_tap
,
size
=
storage_size_name
,
allow_scalar
=
True
,
)
for
out_tap
in
output_taps
)
add_output_storage_post_proc_stmt
(
storage_name
,
output_taps
,
storage_size_name
...
...
pytensor/link/vm.py
浏览文件 @
8a6d2aae
...
...
@@ -1111,9 +1111,8 @@ class VMLinker(LocalLinker):
# builds the list of prereqs induced by e.g. destroy_handler
ords
=
self
.
fgraph
.
orderings
()
node_prereqs
=
[]
node_output_size
=
[
]
node_output_size
=
[
0
]
*
len
(
nodes
)
for
i
,
node
in
enumerate
(
nodes
):
node_output_size
.
append
(
0
)
prereq_var_idxs
=
[]
for
prereq_node
in
ords
.
get
(
node
,
[]):
prereq_var_idxs
.
extend
([
vars_idx
[
v
]
for
v
in
prereq_node
.
outputs
])
...
...
pytensor/scalar/basic.py
浏览文件 @
8a6d2aae
...
...
@@ -1575,9 +1575,7 @@ class InRange(LogicalComparison):
def
L_op
(
self
,
inputs
,
outputs
,
gout
):
(
x
,
low
,
hi
)
=
inputs
(
gz
,)
=
gout
grads
=
[]
for
elem
in
[
x
,
low
,
hi
]:
grads
.
append
(
self
.
get_grad
(
elem
))
grads
=
[
self
.
get_grad
(
elem
)
for
elem
in
[
x
,
low
,
hi
]]
return
grads
...
...
pytensor/scan/basic.py
浏览文件 @
8a6d2aae
...
...
@@ -646,9 +646,7 @@ def scan(
# Since we've added all sequences now we need to level them up based on
# n_steps or their different shapes
lengths_vec
=
[]
for
seq
in
scan_seqs
:
lengths_vec
.
append
(
seq
.
shape
[
0
])
lengths_vec
=
[
seq
.
shape
[
0
]
for
seq
in
scan_seqs
]
if
not
isNaN_or_Inf_or_None
(
n_steps
):
# ^ N_steps should also be considered
...
...
pytensor/tensor/basic.py
浏览文件 @
8a6d2aae
...
...
@@ -1629,10 +1629,7 @@ class Alloc(COp):
return
[
node
.
inputs
[
1
:]]
def
connection_pattern
(
self
,
node
):
rval
=
[[
True
]]
for
ipt
in
node
.
inputs
[
1
:]:
rval
.
append
([
False
])
rval
=
[[
True
],
*
([
False
]
for
_
in
node
.
inputs
[
1
:])]
return
rval
...
...
@@ -1859,9 +1856,7 @@ class MakeVector(COp):
if
self
.
dtype
in
discrete_dtypes
:
return
[
ipt
.
zeros_like
()
.
astype
(
config
.
floatX
)
for
ipt
in
inputs
]
grads
=
[]
for
i
,
inp
in
enumerate
(
inputs
):
grads
.
append
(
output_gradients
[
0
][
i
])
grads
=
[
output_gradients
[
0
][
i
]
for
i
in
range
(
len
(
inputs
))]
return
grads
def
R_op
(
self
,
inputs
,
eval_points
):
...
...
@@ -2514,13 +2509,11 @@ class Join(COp):
(
out
,)
=
outputs
fail
=
sub
[
"fail"
]
adtype
=
node
.
inputs
[
0
]
.
type
.
dtype_specs
()[
1
]
copy_to_list
=
[]
for
i
,
inp
in
enumerate
(
tens
):
copy_to_list
.
append
(
f
"""Py_INCREF({inp});
PyList_SetItem(list, {i}, (PyObject*){inp});"""
)
copy_to_list
=
(
f
"""Py_INCREF({inp}); PyList_SetItem(list, {i}, (PyObject*){inp});"""
for
i
,
inp
in
enumerate
(
tens
)
)
copy_inputs_to_list
=
"
\n
"
.
join
(
copy_to_list
)
n
=
len
(
tens
)
...
...
@@ -3442,9 +3435,7 @@ class PermuteRowElements(Op):
shp_x
=
in_shapes
[
0
]
shp_y
=
in_shapes
[
1
]
assert
len
(
shp_x
)
==
len
(
shp_y
)
out_shape
=
[]
for
i
in
range
(
len
(
shp_x
)):
out_shape
.
append
(
maximum
(
shp_x
[
i
],
shp_y
[
i
]))
out_shape
=
[
maximum
(
sx
,
sy
)
for
sx
,
sy
in
zip
(
shp_x
,
shp_y
,
strict
=
True
)]
return
[
out_shape
]
def
grad
(
self
,
inp
,
grads
):
...
...
pytensor/tensor/blockwise.py
浏览文件 @
8a6d2aae
...
...
@@ -167,9 +167,8 @@ class Blockwise(Op):
batch_ndims
=
self
.
batch_ndim
(
node
)
core_dims
:
dict
[
str
,
Any
]
=
{}
batch_shapes
=
[]
batch_shapes
=
[
input_shape
[:
batch_ndims
]
for
input_shape
in
input_shapes
]
for
input_shape
,
sig
in
zip
(
input_shapes
,
self
.
inputs_sig
):
batch_shapes
.
append
(
input_shape
[:
batch_ndims
])
core_shape
=
input_shape
[
batch_ndims
:]
for
core_dim
,
dim_name
in
zip
(
core_shape
,
sig
):
...
...
pytensor/tensor/elemwise.py
浏览文件 @
8a6d2aae
...
...
@@ -1161,8 +1161,10 @@ class Elemwise(OpenMPOp):
],
)
version
.
append
(
self
.
scalar_op
.
c_code_cache_version_apply
(
scalar_node
))
for
i
in
node
.
inputs
+
node
.
outputs
:
version
.
append
(
get_scalar_type
(
dtype
=
i
.
type
.
dtype
)
.
c_code_cache_version
())
version
.
extend
(
get_scalar_type
(
dtype
=
i
.
type
.
dtype
)
.
c_code_cache_version
()
for
i
in
node
.
inputs
+
node
.
outputs
)
version
.
append
((
"openmp"
,
self
.
openmp
))
version
.
append
((
"openmp_elemwise_minsize"
,
config
.
openmp_elemwise_minsize
))
if
all
(
version
):
...
...
@@ -1664,8 +1666,10 @@ class CAReduce(COp):
],
)
version
.
append
(
self
.
scalar_op
.
c_code_cache_version_apply
(
scalar_node
))
for
i
in
node
.
inputs
+
node
.
outputs
:
version
.
append
(
get_scalar_type
(
dtype
=
i
.
type
.
dtype
)
.
c_code_cache_version
())
version
.
extend
(
get_scalar_type
(
dtype
=
i
.
type
.
dtype
)
.
c_code_cache_version
()
for
i
in
node
.
inputs
+
node
.
outputs
)
if
all
(
version
):
return
tuple
(
version
)
else
:
...
...
pytensor/tensor/subtensor.py
浏览文件 @
8a6d2aae
...
...
@@ -952,10 +952,7 @@ class Subtensor(COp):
return
[
first
]
+
[
DisconnectedType
()()]
*
len
(
rest
)
def
connection_pattern
(
self
,
node
):
rval
=
[[
True
]]
for
ipt
in
node
.
inputs
[
1
:]:
rval
.
append
([
False
])
rval
=
[[
True
],
*
([
False
]
for
_
in
node
.
inputs
[
1
:])]
return
rval
...
...
@@ -1963,10 +1960,7 @@ class IncSubtensor(COp):
return
self
(
eval_points
[
0
],
eval_points
[
1
],
*
inputs
[
2
:],
return_list
=
True
)
def
connection_pattern
(
self
,
node
):
rval
=
[[
True
],
[
True
]]
for
ipt
in
node
.
inputs
[
2
:]:
rval
.
append
([
False
])
rval
=
[[
True
],
[
True
],
*
([
False
]
for
_
in
node
.
inputs
[
2
:])]
return
rval
...
...
@@ -2765,10 +2759,7 @@ class AdvancedSubtensor(Op):
out
[
0
]
=
rval
def
connection_pattern
(
self
,
node
):
rval
=
[[
True
]]
for
ipt
in
node
.
inputs
[
1
:]:
rval
.
append
([
False
])
rval
=
[[
True
],
*
([
False
]
for
_
in
node
.
inputs
[
1
:])]
return
rval
...
...
@@ -2912,10 +2903,7 @@ class AdvancedIncSubtensor(Op):
return
[
ishapes
[
0
]]
def
connection_pattern
(
self
,
node
):
rval
=
[[
True
],
[
True
]]
for
ipt
in
node
.
inputs
[
2
:]:
rval
.
append
([
False
])
rval
=
[[
True
],
[
True
],
*
([
False
]
for
_
in
node
.
inputs
[
2
:])]
return
rval
...
...
pytensor/typed_list/basic.py
浏览文件 @
8a6d2aae
...
...
@@ -238,8 +238,7 @@ class Extend(COp):
# need to copy toAppend due to destroy_handler limitation
if
toAppend
:
o
=
out
[
0
]
for
i
in
toAppend
:
o
.
append
(
_lessbroken_deepcopy
(
i
))
o
.
extend
(
_lessbroken_deepcopy
(
i
)
for
i
in
toAppend
)
def
__str__
(
self
):
return
self
.
__class__
.
__name__
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论