Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
1da2891c
提交
1da2891c
authored
5月 14, 2023
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
5月 14, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add flake8-comprehensions plugin
上级
bc878138
显示空白字符变更
内嵌
并排
正在显示
27 个修改的文件
包含
55 行增加
和
60 行删除
+55
-60
.pre-commit-config.yaml
.pre-commit-config.yaml
+2
-0
builders.py
pytensor/compile/builders.py
+1
-3
gradient.py
pytensor/gradient.py
+3
-3
basic.py
pytensor/graph/basic.py
+1
-1
replace.py
pytensor/graph/replace.py
+1
-1
basic.py
pytensor/graph/rewriting/basic.py
+2
-2
basic.py
pytensor/link/c/basic.py
+2
-2
params_type.py
pytensor/link/c/params_type.py
+1
-1
elemwise.py
pytensor/link/jax/dispatch/elemwise.py
+1
-1
elemwise.py
pytensor/link/numba/dispatch/elemwise.py
+1
-1
printing.py
pytensor/printing.py
+1
-1
basic.py
pytensor/scan/basic.py
+2
-2
op.py
pytensor/scan/op.py
+5
-8
rewriting.py
pytensor/scan/rewriting.py
+1
-1
op.py
pytensor/tensor/random/op.py
+1
-1
math.py
pytensor/tensor/rewriting/math.py
+3
-3
shape.py
pytensor/tensor/shape.py
+3
-3
subtensor.py
pytensor/tensor/subtensor.py
+1
-1
setup.cfg
setup.cfg
+1
-1
test_features.py
tests/graph/test_features.py
+2
-2
test_op.py
tests/graph/test_op.py
+1
-1
test_elemwise.py
tests/tensor/rewriting/test_elemwise.py
+1
-1
test_math.py
tests/tensor/rewriting/test_math.py
+3
-3
test_subtensor.py
tests/tensor/rewriting/test_subtensor.py
+2
-2
test_blas.py
tests/tensor/test_blas.py
+3
-3
test_complex.py
tests/tensor/test_complex.py
+1
-3
test_elemwise.py
tests/tensor/test_elemwise.py
+9
-9
没有找到文件。
.pre-commit-config.yaml
浏览文件 @
1da2891c
...
@@ -33,6 +33,8 @@ repos:
...
@@ -33,6 +33,8 @@ repos:
rev
:
6.0.0
rev
:
6.0.0
hooks
:
hooks
:
-
id
:
flake8
-
id
:
flake8
additional_dependencies
:
-
flake8-comprehensions
-
repo
:
https://github.com/pycqa/isort
-
repo
:
https://github.com/pycqa/isort
rev
:
5.12.0
rev
:
5.12.0
hooks
:
hooks
:
...
...
pytensor/compile/builders.py
浏览文件 @
1da2891c
...
@@ -969,9 +969,7 @@ def inline_ofg_expansion(fgraph, node):
...
@@ -969,9 +969,7 @@ def inline_ofg_expansion(fgraph, node):
return
False
return
False
if
not
op
.
is_inline
:
if
not
op
.
is_inline
:
return
False
return
False
return
clone_replace
(
return
clone_replace
(
op
.
inner_outputs
,
dict
(
zip
(
op
.
inner_inputs
,
node
.
inputs
)))
op
.
inner_outputs
,
{
u
:
v
for
u
,
v
in
zip
(
op
.
inner_inputs
,
node
.
inputs
)}
)
# We want to run this before the first merge optimizer
# We want to run this before the first merge optimizer
...
...
pytensor/gradient.py
浏览文件 @
1da2891c
...
@@ -504,7 +504,7 @@ def grad(
...
@@ -504,7 +504,7 @@ def grad(
if
not
isinstance
(
wrt
,
Sequence
):
if
not
isinstance
(
wrt
,
Sequence
):
_wrt
:
List
[
Variable
]
=
[
wrt
]
_wrt
:
List
[
Variable
]
=
[
wrt
]
else
:
else
:
_wrt
=
[
x
for
x
in
wrt
]
_wrt
=
list
(
wrt
)
outputs
=
[]
outputs
=
[]
if
cost
is
not
None
:
if
cost
is
not
None
:
...
@@ -791,8 +791,8 @@ def subgraph_grad(wrt, end, start=None, cost=None, details=False):
...
@@ -791,8 +791,8 @@ def subgraph_grad(wrt, end, start=None, cost=None, details=False):
pgrads
=
dict
(
zip
(
params
,
grads
))
pgrads
=
dict
(
zip
(
params
,
grads
))
# separate wrt from end grads:
# separate wrt from end grads:
wrt_grads
=
list
(
pgrads
[
k
]
for
k
in
wrt
)
wrt_grads
=
[
pgrads
[
k
]
for
k
in
wrt
]
end_grads
=
list
(
pgrads
[
k
]
for
k
in
end
)
end_grads
=
[
pgrads
[
k
]
for
k
in
end
]
if
details
:
if
details
:
return
wrt_grads
,
end_grads
,
start_grads
,
cost_grads
return
wrt_grads
,
end_grads
,
start_grads
,
cost_grads
...
...
pytensor/graph/basic.py
浏览文件 @
1da2891c
...
@@ -1629,7 +1629,7 @@ def as_string(
...
@@ -1629,7 +1629,7 @@ def as_string(
multi
.
add
(
op2
)
multi
.
add
(
op2
)
else
:
else
:
seen
.
add
(
input
.
owner
)
seen
.
add
(
input
.
owner
)
multi_list
=
[
x
for
x
in
multi
]
multi_list
=
list
(
multi
)
done
:
Set
=
set
()
done
:
Set
=
set
()
def
multi_index
(
x
):
def
multi_index
(
x
):
...
...
pytensor/graph/replace.py
浏览文件 @
1da2891c
...
@@ -142,7 +142,7 @@ def graph_replace(
...
@@ -142,7 +142,7 @@ def graph_replace(
raise
ValueError
(
f
"{key} is not a part of graph"
)
raise
ValueError
(
f
"{key} is not a part of graph"
)
sorted_replacements
=
sorted
(
sorted_replacements
=
sorted
(
tuple
(
fg_replace
.
items
()
),
fg_replace
.
items
(
),
# sort based on the fg toposort, if a variable has no owner, it goes first
# sort based on the fg toposort, if a variable has no owner, it goes first
key
=
partial
(
toposort_key
,
fg
,
toposort
),
key
=
partial
(
toposort_key
,
fg
,
toposort
),
reverse
=
True
,
reverse
=
True
,
...
...
pytensor/graph/rewriting/basic.py
浏览文件 @
1da2891c
...
@@ -2575,8 +2575,8 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter):
...
@@ -2575,8 +2575,8 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter):
for
i
in
range
(
len
(
loop_timing
)):
for
i
in
range
(
len
(
loop_timing
)):
loop_times
=
""
loop_times
=
""
if
loop_process_count
[
i
]:
if
loop_process_count
[
i
]:
d
=
list
(
d
=
sorted
(
reversed
(
sorted
(
loop_process_count
[
i
]
.
items
(),
key
=
lambda
a
:
a
[
1
]))
loop_process_count
[
i
]
.
items
(),
key
=
lambda
a
:
a
[
1
],
reverse
=
True
)
)
loop_times
=
" "
.
join
([
str
((
str
(
k
),
v
))
for
k
,
v
in
d
[:
5
]])
loop_times
=
" "
.
join
([
str
((
str
(
k
),
v
))
for
k
,
v
in
d
[:
5
]])
if
len
(
d
)
>
5
:
if
len
(
d
)
>
5
:
...
...
pytensor/link/c/basic.py
浏览文件 @
1da2891c
...
@@ -633,11 +633,11 @@ class CLinker(Linker):
...
@@ -633,11 +633,11 @@ class CLinker(Linker):
# The orphans field is listified to ensure a consistent order.
# The orphans field is listified to ensure a consistent order.
# list(fgraph.orphans.difference(self.outputs))
# list(fgraph.orphans.difference(self.outputs))
self
.
orphans
=
list
(
self
.
orphans
=
[
r
r
for
r
in
self
.
variables
for
r
in
self
.
variables
if
isinstance
(
r
,
AtomicVariable
)
and
r
not
in
self
.
inputs
if
isinstance
(
r
,
AtomicVariable
)
and
r
not
in
self
.
inputs
)
]
# C type constants (pytensor.scalar.ScalarType). They don't request an object
# C type constants (pytensor.scalar.ScalarType). They don't request an object
self
.
consts
=
[]
self
.
consts
=
[]
# Move c type from orphans (pytensor.scalar.ScalarType) to self.consts
# Move c type from orphans (pytensor.scalar.ScalarType) to self.consts
...
...
pytensor/link/c/params_type.py
浏览文件 @
1da2891c
...
@@ -810,7 +810,7 @@ class ParamsType(CType):
...
@@ -810,7 +810,7 @@ class ParamsType(CType):
struct_extract_method
=
struct_extract_method
,
struct_extract_method
=
struct_extract_method
,
)
)
return
list
(
sorted
(
list
(
c_support_code_set
))
)
+
[
final_struct_code
]
return
sorted
(
c_support_code_set
)
+
[
final_struct_code
]
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
return
((
3
,),
tuple
(
t
.
c_code_cache_version
()
for
t
in
self
.
types
))
return
((
3
,),
tuple
(
t
.
c_code_cache_version
()
for
t
in
self
.
types
))
...
...
pytensor/link/jax/dispatch/elemwise.py
浏览文件 @
1da2891c
...
@@ -41,7 +41,7 @@ def jax_funcify_CAReduce(op, **kwargs):
...
@@ -41,7 +41,7 @@ def jax_funcify_CAReduce(op, **kwargs):
elif
scalar_op_name
:
elif
scalar_op_name
:
scalar_fn_name
=
scalar_op_name
scalar_fn_name
=
scalar_op_name
to_reduce
=
reversed
(
sorted
(
axis
)
)
to_reduce
=
sorted
(
axis
,
reverse
=
True
)
if
to_reduce
:
if
to_reduce
:
# In this case, we need to use the `jax.lax` function (if there
# In this case, we need to use the `jax.lax` function (if there
...
...
pytensor/link/numba/dispatch/elemwise.py
浏览文件 @
1da2891c
...
@@ -361,7 +361,7 @@ def create_multiaxis_reducer(
...
@@ -361,7 +361,7 @@ def create_multiaxis_reducer(
careduce_fn_name
=
f
"careduce_{scalar_op}"
careduce_fn_name
=
f
"careduce_{scalar_op}"
global_env
=
{}
global_env
=
{}
to_reduce
=
reversed
(
sorted
(
axes
)
)
to_reduce
=
sorted
(
axes
,
reverse
=
True
)
careduce_lines_src
=
[]
careduce_lines_src
=
[]
var_name
=
input_name
var_name
=
input_name
...
...
pytensor/printing.py
浏览文件 @
1da2891c
...
@@ -796,7 +796,7 @@ class Print(Op):
...
@@ -796,7 +796,7 @@ class Print(Op):
return
output_gradients
return
output_gradients
def
R_op
(
self
,
inputs
,
eval_points
):
def
R_op
(
self
,
inputs
,
eval_points
):
return
[
x
for
x
in
eval_points
]
return
list
(
eval_points
)
def
__setstate__
(
self
,
dct
):
def
__setstate__
(
self
,
dct
):
dct
.
setdefault
(
"global_fn"
,
_print_fn
)
dct
.
setdefault
(
"global_fn"
,
_print_fn
)
...
...
pytensor/scan/basic.py
浏览文件 @
1da2891c
...
@@ -492,7 +492,7 @@ def scan(
...
@@ -492,7 +492,7 @@ def scan(
# wrap sequences in a dictionary if they are not already dictionaries
# wrap sequences in a dictionary if they are not already dictionaries
for
i
in
range
(
n_seqs
):
for
i
in
range
(
n_seqs
):
if
not
isinstance
(
seqs
[
i
],
dict
):
if
not
isinstance
(
seqs
[
i
],
dict
):
seqs
[
i
]
=
dict
([(
"input"
,
seqs
[
i
]),
(
"taps"
,
[
0
])])
seqs
[
i
]
=
{
"input"
:
seqs
[
i
],
"taps"
:
[
0
]}
elif
seqs
[
i
]
.
get
(
"taps"
,
None
)
is
not
None
:
elif
seqs
[
i
]
.
get
(
"taps"
,
None
)
is
not
None
:
seqs
[
i
][
"taps"
]
=
wrap_into_list
(
seqs
[
i
][
"taps"
])
seqs
[
i
][
"taps"
]
=
wrap_into_list
(
seqs
[
i
][
"taps"
])
elif
seqs
[
i
]
.
get
(
"taps"
,
None
)
is
None
:
elif
seqs
[
i
]
.
get
(
"taps"
,
None
)
is
None
:
...
@@ -504,7 +504,7 @@ def scan(
...
@@ -504,7 +504,7 @@ def scan(
if
outs_info
[
i
]
is
not
None
:
if
outs_info
[
i
]
is
not
None
:
if
not
isinstance
(
outs_info
[
i
],
dict
):
if
not
isinstance
(
outs_info
[
i
],
dict
):
# by default any output has a tap value of -1
# by default any output has a tap value of -1
outs_info
[
i
]
=
dict
([(
"initial"
,
outs_info
[
i
]),
(
"taps"
,
[
-
1
])])
outs_info
[
i
]
=
{
"initial"
:
outs_info
[
i
],
"taps"
:
[
-
1
]}
elif
(
elif
(
outs_info
[
i
]
.
get
(
"initial"
,
None
)
is
None
outs_info
[
i
]
.
get
(
"initial"
,
None
)
is
None
and
outs_info
[
i
]
.
get
(
"taps"
,
None
)
is
not
None
and
outs_info
[
i
]
.
get
(
"taps"
,
None
)
is
not
None
...
...
pytensor/scan/op.py
浏览文件 @
1da2891c
...
@@ -1718,12 +1718,9 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -1718,12 +1718,9 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
arg
.
shape
[
0
]
arg
.
shape
[
0
]
for
arg
in
inputs
[
self
.
seqs_arg_offset
:
self
.
shared_arg_offset
]
for
arg
in
inputs
[
self
.
seqs_arg_offset
:
self
.
shared_arg_offset
]
]
]
store_steps
+=
[
store_steps
+=
list
(
arg
inputs
[
self
.
nit_sot_arg_offset
:
self
.
nit_sot_arg_offset
+
info
.
n_nit_sot
]
for
arg
in
inputs
[
)
self
.
nit_sot_arg_offset
:
self
.
nit_sot_arg_offset
+
info
.
n_nit_sot
]
]
# 2.1 Create storage space for outputs
# 2.1 Create storage space for outputs
for
idx
in
range
(
self
.
n_outs
):
for
idx
in
range
(
self
.
n_outs
):
...
@@ -2270,7 +2267,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -2270,7 +2267,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
)
)
offset
=
1
+
info
.
n_seqs
offset
=
1
+
info
.
n_seqs
scan_outs
=
[
x
for
x
in
input_shapes
[
offset
:
offset
+
n_outs
]]
scan_outs
=
list
(
input_shapes
[
offset
:
offset
+
n_outs
])
offset
+=
n_outs
offset
+=
n_outs
outs_shape_n
=
info
.
n_mit_mot_outs
+
info
.
n_mit_sot
+
info
.
n_sit_sot
outs_shape_n
=
info
.
n_mit_mot_outs
+
info
.
n_mit_sot
+
info
.
n_sit_sot
for
x
in
range
(
info
.
n_nit_sot
):
for
x
in
range
(
info
.
n_nit_sot
):
...
@@ -2301,7 +2298,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -2301,7 +2298,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
shp
.
append
(
v_shp_i
[
0
])
shp
.
append
(
v_shp_i
[
0
])
scan_outs
.
append
(
tuple
(
shp
))
scan_outs
.
append
(
tuple
(
shp
))
scan_outs
+=
[
x
for
x
in
input_shapes
[
offset
:
offset
+
info
.
n_shared_outs
]]
scan_outs
+=
list
(
input_shapes
[
offset
:
offset
+
info
.
n_shared_outs
])
# if we are dealing with a repeat-until, then we do not know the
# if we are dealing with a repeat-until, then we do not know the
# leading dimension so we replace it for every entry with Shape_i
# leading dimension so we replace it for every entry with Shape_i
if
info
.
as_while
:
if
info
.
as_while
:
...
...
pytensor/scan/rewriting.py
浏览文件 @
1da2891c
...
@@ -388,7 +388,7 @@ def push_out_non_seq_scan(fgraph, node):
...
@@ -388,7 +388,7 @@ def push_out_non_seq_scan(fgraph, node):
if
out
in
local_fgraph_outs_set
:
if
out
in
local_fgraph_outs_set
:
x
=
node
.
outputs
[
local_fgraph_outs_map
[
out
]]
x
=
node
.
outputs
[
local_fgraph_outs_map
[
out
]]
y
=
replace_with_out
[
idx
]
y
=
replace_with_out
[
idx
]
y_shape
=
[
shp
for
shp
in
y
.
shape
]
y_shape
=
list
(
y
.
shape
)
replace_with
[
x
]
=
at
.
alloc
(
y
,
node
.
inputs
[
0
],
*
y_shape
)
replace_with
[
x
]
=
at
.
alloc
(
y
,
node
.
inputs
[
0
],
*
y_shape
)
# We need to add one extra dimension to the outputs
# We need to add one extra dimension to the outputs
...
...
pytensor/tensor/random/op.py
浏览文件 @
1da2891c
...
@@ -283,7 +283,7 @@ class RandomVariable(Op):
...
@@ -283,7 +283,7 @@ class RandomVariable(Op):
shape
=
self
.
_infer_shape
(
size
,
dist_params
,
param_shapes
=
param_shapes
)
shape
=
self
.
_infer_shape
(
size
,
dist_params
,
param_shapes
=
param_shapes
)
return
[
None
,
[
s
for
s
in
shape
]
]
return
[
None
,
list
(
shape
)
]
def
__call__
(
self
,
*
args
,
size
=
None
,
name
=
None
,
rng
=
None
,
dtype
=
None
,
**
kwargs
):
def
__call__
(
self
,
*
args
,
size
=
None
,
name
=
None
,
rng
=
None
,
dtype
=
None
,
**
kwargs
):
res
=
super
()
.
__call__
(
rng
,
size
,
dtype
,
*
args
,
**
kwargs
)
res
=
super
()
.
__call__
(
rng
,
size
,
dtype
,
*
args
,
**
kwargs
)
...
...
pytensor/tensor/rewriting/math.py
浏览文件 @
1da2891c
...
@@ -1555,11 +1555,11 @@ def local_sum_prod_div_dimshuffle(fgraph, node):
...
@@ -1555,11 +1555,11 @@ def local_sum_prod_div_dimshuffle(fgraph, node):
)
)
if
len
(
compatible_dims
)
>
0
:
if
len
(
compatible_dims
)
>
0
:
optimized_dimshuffle_order
=
list
(
optimized_dimshuffle_order
=
[
ax
ax
for
i
,
ax
in
enumerate
(
dimshuffle_order
)
for
i
,
ax
in
enumerate
(
dimshuffle_order
)
if
(
i
not
in
axis
)
or
(
ax
!=
"x"
)
if
(
i
not
in
axis
)
or
(
ax
!=
"x"
)
)
]
# Removing leading 'x' (since it will be done automatically)
# Removing leading 'x' (since it will be done automatically)
while
(
while
(
...
@@ -1644,7 +1644,7 @@ def local_op_of_op(fgraph, node):
...
@@ -1644,7 +1644,7 @@ def local_op_of_op(fgraph, node):
return
[
op_type
(
None
,
dtype
=
out_dtype
)(
node_inps
.
owner
.
inputs
[
0
])]
return
[
op_type
(
None
,
dtype
=
out_dtype
)(
node_inps
.
owner
.
inputs
[
0
])]
# figure out which axes were in the original sum
# figure out which axes were in the original sum
newaxis
=
list
(
tuple
(
node_inps
.
owner
.
op
.
axis
)
)
newaxis
=
list
(
node_inps
.
owner
.
op
.
axis
)
for
i
in
node
.
op
.
axis
:
for
i
in
node
.
op
.
axis
:
new_i
=
i
new_i
=
i
for
ii
in
node_inps
.
owner
.
op
.
axis
:
for
ii
in
node_inps
.
owner
.
op
.
axis
:
...
...
pytensor/tensor/shape.py
浏览文件 @
1da2891c
...
@@ -810,7 +810,7 @@ def shape_padleft(t, n_ones=1):
...
@@ -810,7 +810,7 @@ def shape_padleft(t, n_ones=1):
"""
"""
_t
=
at
.
as_tensor_variable
(
t
)
_t
=
at
.
as_tensor_variable
(
t
)
pattern
=
[
"x"
]
*
n_ones
+
[
i
for
i
in
range
(
_t
.
type
.
ndim
)]
pattern
=
[
"x"
]
*
n_ones
+
list
(
range
(
_t
.
type
.
ndim
))
return
_t
.
dimshuffle
(
pattern
)
return
_t
.
dimshuffle
(
pattern
)
...
@@ -826,7 +826,7 @@ def shape_padright(t, n_ones=1):
...
@@ -826,7 +826,7 @@ def shape_padright(t, n_ones=1):
"""
"""
_t
=
at
.
as_tensor_variable
(
t
)
_t
=
at
.
as_tensor_variable
(
t
)
pattern
=
[
i
for
i
in
range
(
_t
.
type
.
ndim
)]
+
[
"x"
]
*
n_ones
pattern
=
list
(
range
(
_t
.
type
.
ndim
))
+
[
"x"
]
*
n_ones
return
_t
.
dimshuffle
(
pattern
)
return
_t
.
dimshuffle
(
pattern
)
...
@@ -861,7 +861,7 @@ def shape_padaxis(t, axis):
...
@@ -861,7 +861,7 @@ def shape_padaxis(t, axis):
if
axis
<
0
:
if
axis
<
0
:
axis
+=
ndim
axis
+=
ndim
pattern
=
[
i
for
i
in
range
(
_t
.
type
.
ndim
)]
pattern
=
list
(
range
(
_t
.
type
.
ndim
))
pattern
.
insert
(
axis
,
"x"
)
pattern
.
insert
(
axis
,
"x"
)
return
_t
.
dimshuffle
(
pattern
)
return
_t
.
dimshuffle
(
pattern
)
...
...
pytensor/tensor/subtensor.py
浏览文件 @
1da2891c
...
@@ -2604,7 +2604,7 @@ class AdvancedSubtensor(Op):
...
@@ -2604,7 +2604,7 @@ class AdvancedSubtensor(Op):
ishapes
[
0
],
index_shapes
,
indices_are_shapes
=
True
ishapes
[
0
],
index_shapes
,
indices_are_shapes
=
True
)
)
assert
node
.
outputs
[
0
]
.
ndim
==
len
(
res_shape
)
assert
node
.
outputs
[
0
]
.
ndim
==
len
(
res_shape
)
return
[
[
s
for
s
in
res_shape
]
]
return
[
list
(
res_shape
)
]
def
perform
(
self
,
node
,
inputs
,
out_
):
def
perform
(
self
,
node
,
inputs
,
out_
):
(
out
,)
=
out_
(
out
,)
=
out_
...
...
setup.cfg
浏览文件 @
1da2891c
[flake8]
[flake8]
select = C,E,F,W
select = C,E,F,W
ignore = E203,E231,E501,E741,W503,W504,C901
ignore = E203,E231,E501,E741,W503,W504,C
408,C
901
per-file-ignores =
per-file-ignores =
**/__init__.py:F401,E402,F403
**/__init__.py:F401,E402,F403
pytensor/tensor/linalg.py:F401,F403
pytensor/tensor/linalg.py:F401,F403
...
...
tests/graph/test_features.py
浏览文件 @
1da2891c
...
@@ -73,7 +73,7 @@ class TestNodeFinder:
...
@@ -73,7 +73,7 @@ class TestNodeFinder:
assert
hasattr
(
g
,
"get_nodes"
)
assert
hasattr
(
g
,
"get_nodes"
)
for
type
,
num
in
((
add
,
3
),
(
sigmoid
,
3
),
(
dot
,
2
)):
for
type
,
num
in
((
add
,
3
),
(
sigmoid
,
3
),
(
dot
,
2
)):
if
len
(
[
t
for
t
in
g
.
get_nodes
(
type
)]
)
!=
num
:
if
len
(
list
(
g
.
get_nodes
(
type
))
)
!=
num
:
raise
Exception
(
"Expected:
%
i times
%
s"
%
(
num
,
type
))
raise
Exception
(
"Expected:
%
i times
%
s"
%
(
num
,
type
))
new_e0
=
add
(
y
,
z
)
new_e0
=
add
(
y
,
z
)
assert
e0
.
owner
in
g
.
get_nodes
(
dot
)
assert
e0
.
owner
in
g
.
get_nodes
(
dot
)
...
@@ -82,7 +82,7 @@ class TestNodeFinder:
...
@@ -82,7 +82,7 @@ class TestNodeFinder:
assert
e0
.
owner
not
in
g
.
get_nodes
(
dot
)
assert
e0
.
owner
not
in
g
.
get_nodes
(
dot
)
assert
new_e0
.
owner
in
g
.
get_nodes
(
add
)
assert
new_e0
.
owner
in
g
.
get_nodes
(
add
)
for
type
,
num
in
((
add
,
4
),
(
sigmoid
,
3
),
(
dot
,
1
)):
for
type
,
num
in
((
add
,
4
),
(
sigmoid
,
3
),
(
dot
,
1
)):
if
len
(
[
t
for
t
in
g
.
get_nodes
(
type
)]
)
!=
num
:
if
len
(
list
(
g
.
get_nodes
(
type
))
)
!=
num
:
raise
Exception
(
"Expected:
%
i times
%
s"
%
(
num
,
type
))
raise
Exception
(
"Expected:
%
i times
%
s"
%
(
num
,
type
))
...
...
tests/graph/test_op.py
浏览文件 @
1da2891c
...
@@ -87,7 +87,7 @@ class TestOp:
...
@@ -87,7 +87,7 @@ class TestOp:
r1
,
r2
=
MyType
(
1
)(),
MyType
(
2
)()
r1
,
r2
=
MyType
(
1
)(),
MyType
(
2
)()
node
=
MyOp
.
make_node
(
r1
,
r2
)
node
=
MyOp
.
make_node
(
r1
,
r2
)
# Are the inputs what I provided?
# Are the inputs what I provided?
assert
[
x
for
x
in
node
.
inputs
]
==
[
r1
,
r2
]
assert
list
(
node
.
inputs
)
==
[
r1
,
r2
]
# Are the outputs what I expect?
# Are the outputs what I expect?
assert
[
x
.
type
for
x
in
node
.
outputs
]
==
[
MyType
(
3
)]
assert
[
x
.
type
for
x
in
node
.
outputs
]
==
[
MyType
(
3
)]
assert
node
.
outputs
[
0
]
.
owner
is
node
and
node
.
outputs
[
0
]
.
index
==
0
assert
node
.
outputs
[
0
]
.
owner
is
node
and
node
.
outputs
[
0
]
.
index
==
0
...
...
tests/tensor/rewriting/test_elemwise.py
浏览文件 @
1da2891c
...
@@ -1123,7 +1123,7 @@ class TestFusion:
...
@@ -1123,7 +1123,7 @@ class TestFusion:
out
=
dot
(
x
,
y
)
+
x
+
y
+
z
out
=
dot
(
x
,
y
)
+
x
+
y
+
z
f
=
function
([
x
,
y
,
z
],
out
,
mode
=
self
.
mode
)
f
=
function
([
x
,
y
,
z
],
out
,
mode
=
self
.
mode
)
topo
=
[
n
for
n
in
f
.
maker
.
fgraph
.
toposort
()]
topo
=
list
(
f
.
maker
.
fgraph
.
toposort
())
assert
len
(
topo
)
==
2
assert
len
(
topo
)
==
2
assert
topo
[
-
1
]
.
op
.
inplace_pattern
assert
topo
[
-
1
]
.
op
.
inplace_pattern
...
...
tests/tensor/rewriting/test_math.py
浏览文件 @
1da2891c
...
@@ -3994,9 +3994,9 @@ class TestSigmoidUtils:
...
@@ -3994,9 +3994,9 @@ class TestSigmoidUtils:
exp_op
=
exp
exp_op
=
exp
assert
is_1pexp
(
1
+
exp_op
(
x
),
False
)
==
(
False
,
x
)
assert
is_1pexp
(
1
+
exp_op
(
x
),
False
)
==
(
False
,
x
)
assert
is_1pexp
(
exp_op
(
x
)
+
1
,
False
)
==
(
False
,
x
)
assert
is_1pexp
(
exp_op
(
x
)
+
1
,
False
)
==
(
False
,
x
)
for
neg_
,
exp_arg
in
map
(
for
neg_
,
exp_arg
in
(
lambda
x
:
is_1pexp
(
x
,
only_process_constants
=
False
),
is_1pexp
(
x
,
only_process_constants
=
False
)
[(
1
+
exp_op
(
-
x
)),
(
exp_op
(
-
x
)
+
1
)],
for
x
in
[(
1
+
exp_op
(
-
x
)),
(
exp_op
(
-
x
)
+
1
)]
):
):
assert
not
neg_
and
is_same_graph
(
exp_arg
,
-
x
)
assert
not
neg_
and
is_same_graph
(
exp_arg
,
-
x
)
assert
is_1pexp
(
1
-
exp_op
(
x
),
False
)
is
None
assert
is_1pexp
(
1
-
exp_op
(
x
),
False
)
is
None
...
...
tests/tensor/rewriting/test_subtensor.py
浏览文件 @
1da2891c
...
@@ -2004,7 +2004,7 @@ def test_local_subtensor_SpecifyShape_lift(x, s, idx, x_val, s_val):
...
@@ -2004,7 +2004,7 @@ def test_local_subtensor_SpecifyShape_lift(x, s, idx, x_val, s_val):
y_val_fn
=
function
(
y_val_fn
=
function
(
[
x
]
+
list
(
s
),
y
,
on_unused_input
=
"ignore"
,
mode
=
no_rewrites_mode
[
x
]
+
list
(
s
),
y
,
on_unused_input
=
"ignore"
,
mode
=
no_rewrites_mode
)
)
y_val
=
y_val_fn
(
*
([
x_val
]
+
[
s_
for
s_
in
s_val
]
))
y_val
=
y_val_fn
(
*
([
x_val
]
+
list
(
s_val
)
))
# This optimization should appear in the canonicalizations
# This optimization should appear in the canonicalizations
y_opt
=
rewrite_graph
(
y
,
clone
=
False
)
y_opt
=
rewrite_graph
(
y
,
clone
=
False
)
...
@@ -2017,7 +2017,7 @@ def test_local_subtensor_SpecifyShape_lift(x, s, idx, x_val, s_val):
...
@@ -2017,7 +2017,7 @@ def test_local_subtensor_SpecifyShape_lift(x, s, idx, x_val, s_val):
assert
isinstance
(
y_opt
.
owner
.
op
,
SpecifyShape
)
assert
isinstance
(
y_opt
.
owner
.
op
,
SpecifyShape
)
y_opt_fn
=
function
([
x
]
+
list
(
s
),
y_opt
,
on_unused_input
=
"ignore"
)
y_opt_fn
=
function
([
x
]
+
list
(
s
),
y_opt
,
on_unused_input
=
"ignore"
)
y_opt_val
=
y_opt_fn
(
*
([
x_val
]
+
[
s_
for
s_
in
s_val
]
))
y_opt_val
=
y_opt_fn
(
*
([
x_val
]
+
list
(
s_val
)
))
assert
np
.
allclose
(
y_val
,
y_opt_val
)
assert
np
.
allclose
(
y_val
,
y_opt_val
)
...
...
tests/tensor/test_blas.py
浏览文件 @
1da2891c
...
@@ -2589,10 +2589,10 @@ TestBatchedDot = makeTester(
...
@@ -2589,10 +2589,10 @@ TestBatchedDot = makeTester(
op
=
batched_dot
,
op
=
batched_dot
,
expected
=
(
expected
=
(
lambda
xs
,
ys
:
np
.
asarray
(
lambda
xs
,
ys
:
np
.
asarray
(
list
(
[
x
*
y
if
x
.
ndim
==
0
or
y
.
ndim
==
0
else
np
.
dot
(
x
,
y
)
x
*
y
if
x
.
ndim
==
0
or
y
.
ndim
==
0
else
np
.
dot
(
x
,
y
)
for
x
,
y
in
zip
(
xs
,
ys
)
for
x
,
y
in
zip
(
xs
,
ys
)
)
,
]
,
dtype
=
aes
.
upcast
(
xs
.
dtype
,
ys
.
dtype
),
dtype
=
aes
.
upcast
(
xs
.
dtype
,
ys
.
dtype
),
)
)
),
),
...
@@ -2694,7 +2694,7 @@ def test_batched_dot_not_contiguous():
...
@@ -2694,7 +2694,7 @@ def test_batched_dot_not_contiguous():
assert
x
.
strides
[
0
]
==
direction
*
np
.
dtype
(
config
.
floatX
)
.
itemsize
assert
x
.
strides
[
0
]
==
direction
*
np
.
dtype
(
config
.
floatX
)
.
itemsize
assert
not
(
x
.
flags
[
"C_CONTIGUOUS"
]
or
x
.
flags
[
"F_CONTIGUOUS"
])
assert
not
(
x
.
flags
[
"C_CONTIGUOUS"
]
or
x
.
flags
[
"F_CONTIGUOUS"
])
result
=
f
(
x
,
w
)
result
=
f
(
x
,
w
)
ref_result
=
np
.
asarray
(
list
(
np
.
dot
(
u
,
v
)
for
u
,
v
in
zip
(
x
,
w
))
)
ref_result
=
np
.
asarray
(
[
np
.
dot
(
u
,
v
)
for
u
,
v
in
zip
(
x
,
w
)]
)
utt
.
assert_allclose
(
ref_result
,
result
)
utt
.
assert_allclose
(
ref_result
,
result
)
for
inverted
in
(
0
,
1
):
for
inverted
in
(
0
,
1
):
...
...
tests/tensor/test_complex.py
浏览文件 @
1da2891c
...
@@ -15,9 +15,7 @@ class TestRealImag:
...
@@ -15,9 +15,7 @@ class TestRealImag:
x
=
zvector
()
x
=
zvector
()
rng
=
np
.
random
.
default_rng
(
23
)
rng
=
np
.
random
.
default_rng
(
23
)
xval
=
np
.
asarray
(
xval
=
np
.
asarray
(
list
(
[
complex
(
rng
.
standard_normal
(),
rng
.
standard_normal
())
for
i
in
range
(
10
)]
complex
(
rng
.
standard_normal
(),
rng
.
standard_normal
())
for
i
in
range
(
10
)
)
)
)
assert
np
.
all
(
xval
.
real
==
pytensor
.
function
([
x
],
real
(
x
))(
xval
))
assert
np
.
all
(
xval
.
real
==
pytensor
.
function
([
x
],
real
(
x
))(
xval
))
assert
np
.
all
(
xval
.
imag
==
pytensor
.
function
([
x
],
imag
(
x
))(
xval
))
assert
np
.
all
(
xval
.
imag
==
pytensor
.
function
([
x
],
imag
(
x
))(
xval
))
...
...
tests/tensor/test_elemwise.py
浏览文件 @
1da2891c
...
@@ -490,50 +490,50 @@ class TestCAReduce(unittest_tools.InferShapeTester):
...
@@ -490,50 +490,50 @@ class TestCAReduce(unittest_tools.InferShapeTester):
assert
len
(
axis2
)
==
len
(
tosum
)
assert
len
(
axis2
)
==
len
(
tosum
)
tosum
=
tuple
(
axis2
)
tosum
=
tuple
(
axis2
)
if
tensor_op
==
at_all
:
if
tensor_op
==
at_all
:
for
axis
in
reversed
(
sorted
(
tosum
)
):
for
axis
in
sorted
(
tosum
,
reverse
=
True
):
zv
=
np
.
all
(
zv
,
axis
)
zv
=
np
.
all
(
zv
,
axis
)
if
len
(
tosum
)
==
0
:
if
len
(
tosum
)
==
0
:
zv
=
zv
!=
0
zv
=
zv
!=
0
elif
tensor_op
==
at_any
:
elif
tensor_op
==
at_any
:
for
axis
in
reversed
(
sorted
(
tosum
)
):
for
axis
in
sorted
(
tosum
,
reverse
=
True
):
zv
=
np
.
any
(
zv
,
axis
)
zv
=
np
.
any
(
zv
,
axis
)
if
len
(
tosum
)
==
0
:
if
len
(
tosum
)
==
0
:
zv
=
zv
!=
0
zv
=
zv
!=
0
elif
scalar_op
==
aes
.
add
:
elif
scalar_op
==
aes
.
add
:
for
axis
in
reversed
(
sorted
(
tosum
)
):
for
axis
in
sorted
(
tosum
,
reverse
=
True
):
zv
=
np
.
add
.
reduce
(
zv
,
axis
)
zv
=
np
.
add
.
reduce
(
zv
,
axis
)
if
dtype
==
"bool"
:
if
dtype
==
"bool"
:
# np.add of a bool upcast, while CAReduce don't
# np.add of a bool upcast, while CAReduce don't
zv
=
zv
.
astype
(
dtype
)
zv
=
zv
.
astype
(
dtype
)
elif
scalar_op
==
aes
.
mul
:
elif
scalar_op
==
aes
.
mul
:
for
axis
in
reversed
(
sorted
(
tosum
)
):
for
axis
in
sorted
(
tosum
,
reverse
=
True
):
zv
=
np
.
multiply
.
reduce
(
zv
,
axis
)
zv
=
np
.
multiply
.
reduce
(
zv
,
axis
)
elif
scalar_op
==
aes
.
scalar_maximum
:
elif
scalar_op
==
aes
.
scalar_maximum
:
# There is no identity value for the maximum function
# There is no identity value for the maximum function
# So we can't support shape of dimensions 0.
# So we can't support shape of dimensions 0.
if
np
.
prod
(
zv
.
shape
)
==
0
:
if
np
.
prod
(
zv
.
shape
)
==
0
:
continue
continue
for
axis
in
reversed
(
sorted
(
tosum
)
):
for
axis
in
sorted
(
tosum
,
reverse
=
True
):
zv
=
np
.
maximum
.
reduce
(
zv
,
axis
)
zv
=
np
.
maximum
.
reduce
(
zv
,
axis
)
elif
scalar_op
==
aes
.
scalar_minimum
:
elif
scalar_op
==
aes
.
scalar_minimum
:
# There is no identity value for the minimum function
# There is no identity value for the minimum function
# So we can't support shape of dimensions 0.
# So we can't support shape of dimensions 0.
if
np
.
prod
(
zv
.
shape
)
==
0
:
if
np
.
prod
(
zv
.
shape
)
==
0
:
continue
continue
for
axis
in
reversed
(
sorted
(
tosum
)
):
for
axis
in
sorted
(
tosum
,
reverse
=
True
):
zv
=
np
.
minimum
.
reduce
(
zv
,
axis
)
zv
=
np
.
minimum
.
reduce
(
zv
,
axis
)
elif
scalar_op
==
aes
.
or_
:
elif
scalar_op
==
aes
.
or_
:
for
axis
in
reversed
(
sorted
(
tosum
)
):
for
axis
in
sorted
(
tosum
,
reverse
=
True
):
zv
=
np
.
bitwise_or
.
reduce
(
zv
,
axis
)
zv
=
np
.
bitwise_or
.
reduce
(
zv
,
axis
)
elif
scalar_op
==
aes
.
and_
:
elif
scalar_op
==
aes
.
and_
:
for
axis
in
reversed
(
sorted
(
tosum
)
):
for
axis
in
sorted
(
tosum
,
reverse
=
True
):
zv
=
reduce_bitwise_and
(
zv
,
axis
,
dtype
=
dtype
)
zv
=
reduce_bitwise_and
(
zv
,
axis
,
dtype
=
dtype
)
elif
scalar_op
==
aes
.
xor
:
elif
scalar_op
==
aes
.
xor
:
# There is no identity value for the xor function
# There is no identity value for the xor function
# So we can't support shape of dimensions 0.
# So we can't support shape of dimensions 0.
if
np
.
prod
(
zv
.
shape
)
==
0
:
if
np
.
prod
(
zv
.
shape
)
==
0
:
continue
continue
for
axis
in
reversed
(
sorted
(
tosum
)
):
for
axis
in
sorted
(
tosum
,
reverse
=
True
):
zv
=
np
.
bitwise_xor
.
reduce
(
zv
,
axis
)
zv
=
np
.
bitwise_xor
.
reduce
(
zv
,
axis
)
else
:
else
:
raise
NotImplementedError
(
raise
NotImplementedError
(
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论