Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
1da2891c
提交
1da2891c
authored
5月 14, 2023
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
5月 14, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add flake8-comprehensions plugin
上级
bc878138
隐藏空白字符变更
内嵌
并排
正在显示
27 个修改的文件
包含
55 行增加
和
60 行删除
+55
-60
.pre-commit-config.yaml
.pre-commit-config.yaml
+2
-0
builders.py
pytensor/compile/builders.py
+1
-3
gradient.py
pytensor/gradient.py
+3
-3
basic.py
pytensor/graph/basic.py
+1
-1
replace.py
pytensor/graph/replace.py
+1
-1
basic.py
pytensor/graph/rewriting/basic.py
+2
-2
basic.py
pytensor/link/c/basic.py
+2
-2
params_type.py
pytensor/link/c/params_type.py
+1
-1
elemwise.py
pytensor/link/jax/dispatch/elemwise.py
+1
-1
elemwise.py
pytensor/link/numba/dispatch/elemwise.py
+1
-1
printing.py
pytensor/printing.py
+1
-1
basic.py
pytensor/scan/basic.py
+2
-2
op.py
pytensor/scan/op.py
+5
-8
rewriting.py
pytensor/scan/rewriting.py
+1
-1
op.py
pytensor/tensor/random/op.py
+1
-1
math.py
pytensor/tensor/rewriting/math.py
+3
-3
shape.py
pytensor/tensor/shape.py
+3
-3
subtensor.py
pytensor/tensor/subtensor.py
+1
-1
setup.cfg
setup.cfg
+1
-1
test_features.py
tests/graph/test_features.py
+2
-2
test_op.py
tests/graph/test_op.py
+1
-1
test_elemwise.py
tests/tensor/rewriting/test_elemwise.py
+1
-1
test_math.py
tests/tensor/rewriting/test_math.py
+3
-3
test_subtensor.py
tests/tensor/rewriting/test_subtensor.py
+2
-2
test_blas.py
tests/tensor/test_blas.py
+3
-3
test_complex.py
tests/tensor/test_complex.py
+1
-3
test_elemwise.py
tests/tensor/test_elemwise.py
+9
-9
没有找到文件。
.pre-commit-config.yaml
浏览文件 @
1da2891c
...
...
@@ -33,6 +33,8 @@ repos:
rev
:
6.0.0
hooks
:
-
id
:
flake8
additional_dependencies
:
-
flake8-comprehensions
-
repo
:
https://github.com/pycqa/isort
rev
:
5.12.0
hooks
:
...
...
pytensor/compile/builders.py
浏览文件 @
1da2891c
...
...
@@ -969,9 +969,7 @@ def inline_ofg_expansion(fgraph, node):
return
False
if
not
op
.
is_inline
:
return
False
return
clone_replace
(
op
.
inner_outputs
,
{
u
:
v
for
u
,
v
in
zip
(
op
.
inner_inputs
,
node
.
inputs
)}
)
return
clone_replace
(
op
.
inner_outputs
,
dict
(
zip
(
op
.
inner_inputs
,
node
.
inputs
)))
# We want to run this before the first merge optimizer
...
...
pytensor/gradient.py
浏览文件 @
1da2891c
...
...
@@ -504,7 +504,7 @@ def grad(
if
not
isinstance
(
wrt
,
Sequence
):
_wrt
:
List
[
Variable
]
=
[
wrt
]
else
:
_wrt
=
[
x
for
x
in
wrt
]
_wrt
=
list
(
wrt
)
outputs
=
[]
if
cost
is
not
None
:
...
...
@@ -791,8 +791,8 @@ def subgraph_grad(wrt, end, start=None, cost=None, details=False):
pgrads
=
dict
(
zip
(
params
,
grads
))
# separate wrt from end grads:
wrt_grads
=
list
(
pgrads
[
k
]
for
k
in
wrt
)
end_grads
=
list
(
pgrads
[
k
]
for
k
in
end
)
wrt_grads
=
[
pgrads
[
k
]
for
k
in
wrt
]
end_grads
=
[
pgrads
[
k
]
for
k
in
end
]
if
details
:
return
wrt_grads
,
end_grads
,
start_grads
,
cost_grads
...
...
pytensor/graph/basic.py
浏览文件 @
1da2891c
...
...
@@ -1629,7 +1629,7 @@ def as_string(
multi
.
add
(
op2
)
else
:
seen
.
add
(
input
.
owner
)
multi_list
=
[
x
for
x
in
multi
]
multi_list
=
list
(
multi
)
done
:
Set
=
set
()
def
multi_index
(
x
):
...
...
pytensor/graph/replace.py
浏览文件 @
1da2891c
...
...
@@ -142,7 +142,7 @@ def graph_replace(
raise
ValueError
(
f
"{key} is not a part of graph"
)
sorted_replacements
=
sorted
(
tuple
(
fg_replace
.
items
()
),
fg_replace
.
items
(
),
# sort based on the fg toposort, if a variable has no owner, it goes first
key
=
partial
(
toposort_key
,
fg
,
toposort
),
reverse
=
True
,
...
...
pytensor/graph/rewriting/basic.py
浏览文件 @
1da2891c
...
...
@@ -2575,8 +2575,8 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter):
for
i
in
range
(
len
(
loop_timing
)):
loop_times
=
""
if
loop_process_count
[
i
]:
d
=
list
(
reversed
(
sorted
(
loop_process_count
[
i
]
.
items
(),
key
=
lambda
a
:
a
[
1
]))
d
=
sorted
(
loop_process_count
[
i
]
.
items
(),
key
=
lambda
a
:
a
[
1
],
reverse
=
True
)
loop_times
=
" "
.
join
([
str
((
str
(
k
),
v
))
for
k
,
v
in
d
[:
5
]])
if
len
(
d
)
>
5
:
...
...
pytensor/link/c/basic.py
浏览文件 @
1da2891c
...
...
@@ -633,11 +633,11 @@ class CLinker(Linker):
# The orphans field is listified to ensure a consistent order.
# list(fgraph.orphans.difference(self.outputs))
self
.
orphans
=
list
(
self
.
orphans
=
[
r
for
r
in
self
.
variables
if
isinstance
(
r
,
AtomicVariable
)
and
r
not
in
self
.
inputs
)
]
# C type constants (pytensor.scalar.ScalarType). They don't request an object
self
.
consts
=
[]
# Move c type from orphans (pytensor.scalar.ScalarType) to self.consts
...
...
pytensor/link/c/params_type.py
浏览文件 @
1da2891c
...
...
@@ -810,7 +810,7 @@ class ParamsType(CType):
struct_extract_method
=
struct_extract_method
,
)
return
list
(
sorted
(
list
(
c_support_code_set
))
)
+
[
final_struct_code
]
return
sorted
(
c_support_code_set
)
+
[
final_struct_code
]
def
c_code_cache_version
(
self
):
return
((
3
,),
tuple
(
t
.
c_code_cache_version
()
for
t
in
self
.
types
))
...
...
pytensor/link/jax/dispatch/elemwise.py
浏览文件 @
1da2891c
...
...
@@ -41,7 +41,7 @@ def jax_funcify_CAReduce(op, **kwargs):
elif
scalar_op_name
:
scalar_fn_name
=
scalar_op_name
to_reduce
=
reversed
(
sorted
(
axis
)
)
to_reduce
=
sorted
(
axis
,
reverse
=
True
)
if
to_reduce
:
# In this case, we need to use the `jax.lax` function (if there
...
...
pytensor/link/numba/dispatch/elemwise.py
浏览文件 @
1da2891c
...
...
@@ -361,7 +361,7 @@ def create_multiaxis_reducer(
careduce_fn_name
=
f
"careduce_{scalar_op}"
global_env
=
{}
to_reduce
=
reversed
(
sorted
(
axes
)
)
to_reduce
=
sorted
(
axes
,
reverse
=
True
)
careduce_lines_src
=
[]
var_name
=
input_name
...
...
pytensor/printing.py
浏览文件 @
1da2891c
...
...
@@ -796,7 +796,7 @@ class Print(Op):
return
output_gradients
def
R_op
(
self
,
inputs
,
eval_points
):
return
[
x
for
x
in
eval_points
]
return
list
(
eval_points
)
def
__setstate__
(
self
,
dct
):
dct
.
setdefault
(
"global_fn"
,
_print_fn
)
...
...
pytensor/scan/basic.py
浏览文件 @
1da2891c
...
...
@@ -492,7 +492,7 @@ def scan(
# wrap sequences in a dictionary if they are not already dictionaries
for
i
in
range
(
n_seqs
):
if
not
isinstance
(
seqs
[
i
],
dict
):
seqs
[
i
]
=
dict
([(
"input"
,
seqs
[
i
]),
(
"taps"
,
[
0
])])
seqs
[
i
]
=
{
"input"
:
seqs
[
i
],
"taps"
:
[
0
]}
elif
seqs
[
i
]
.
get
(
"taps"
,
None
)
is
not
None
:
seqs
[
i
][
"taps"
]
=
wrap_into_list
(
seqs
[
i
][
"taps"
])
elif
seqs
[
i
]
.
get
(
"taps"
,
None
)
is
None
:
...
...
@@ -504,7 +504,7 @@ def scan(
if
outs_info
[
i
]
is
not
None
:
if
not
isinstance
(
outs_info
[
i
],
dict
):
# by default any output has a tap value of -1
outs_info
[
i
]
=
dict
([(
"initial"
,
outs_info
[
i
]),
(
"taps"
,
[
-
1
])])
outs_info
[
i
]
=
{
"initial"
:
outs_info
[
i
],
"taps"
:
[
-
1
]}
elif
(
outs_info
[
i
]
.
get
(
"initial"
,
None
)
is
None
and
outs_info
[
i
]
.
get
(
"taps"
,
None
)
is
not
None
...
...
pytensor/scan/op.py
浏览文件 @
1da2891c
...
...
@@ -1718,12 +1718,9 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
arg
.
shape
[
0
]
for
arg
in
inputs
[
self
.
seqs_arg_offset
:
self
.
shared_arg_offset
]
]
store_steps
+=
[
arg
for
arg
in
inputs
[
self
.
nit_sot_arg_offset
:
self
.
nit_sot_arg_offset
+
info
.
n_nit_sot
]
]
store_steps
+=
list
(
inputs
[
self
.
nit_sot_arg_offset
:
self
.
nit_sot_arg_offset
+
info
.
n_nit_sot
]
)
# 2.1 Create storage space for outputs
for
idx
in
range
(
self
.
n_outs
):
...
...
@@ -2270,7 +2267,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
)
offset
=
1
+
info
.
n_seqs
scan_outs
=
[
x
for
x
in
input_shapes
[
offset
:
offset
+
n_outs
]]
scan_outs
=
list
(
input_shapes
[
offset
:
offset
+
n_outs
])
offset
+=
n_outs
outs_shape_n
=
info
.
n_mit_mot_outs
+
info
.
n_mit_sot
+
info
.
n_sit_sot
for
x
in
range
(
info
.
n_nit_sot
):
...
...
@@ -2301,7 +2298,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
shp
.
append
(
v_shp_i
[
0
])
scan_outs
.
append
(
tuple
(
shp
))
scan_outs
+=
[
x
for
x
in
input_shapes
[
offset
:
offset
+
info
.
n_shared_outs
]]
scan_outs
+=
list
(
input_shapes
[
offset
:
offset
+
info
.
n_shared_outs
])
# if we are dealing with a repeat-until, then we do not know the
# leading dimension so we replace it for every entry with Shape_i
if
info
.
as_while
:
...
...
pytensor/scan/rewriting.py
浏览文件 @
1da2891c
...
...
@@ -388,7 +388,7 @@ def push_out_non_seq_scan(fgraph, node):
if
out
in
local_fgraph_outs_set
:
x
=
node
.
outputs
[
local_fgraph_outs_map
[
out
]]
y
=
replace_with_out
[
idx
]
y_shape
=
[
shp
for
shp
in
y
.
shape
]
y_shape
=
list
(
y
.
shape
)
replace_with
[
x
]
=
at
.
alloc
(
y
,
node
.
inputs
[
0
],
*
y_shape
)
# We need to add one extra dimension to the outputs
...
...
pytensor/tensor/random/op.py
浏览文件 @
1da2891c
...
...
@@ -283,7 +283,7 @@ class RandomVariable(Op):
shape
=
self
.
_infer_shape
(
size
,
dist_params
,
param_shapes
=
param_shapes
)
return
[
None
,
[
s
for
s
in
shape
]
]
return
[
None
,
list
(
shape
)
]
def
__call__
(
self
,
*
args
,
size
=
None
,
name
=
None
,
rng
=
None
,
dtype
=
None
,
**
kwargs
):
res
=
super
()
.
__call__
(
rng
,
size
,
dtype
,
*
args
,
**
kwargs
)
...
...
pytensor/tensor/rewriting/math.py
浏览文件 @
1da2891c
...
...
@@ -1555,11 +1555,11 @@ def local_sum_prod_div_dimshuffle(fgraph, node):
)
if
len
(
compatible_dims
)
>
0
:
optimized_dimshuffle_order
=
list
(
optimized_dimshuffle_order
=
[
ax
for
i
,
ax
in
enumerate
(
dimshuffle_order
)
if
(
i
not
in
axis
)
or
(
ax
!=
"x"
)
)
]
# Removing leading 'x' (since it will be done automatically)
while
(
...
...
@@ -1644,7 +1644,7 @@ def local_op_of_op(fgraph, node):
return
[
op_type
(
None
,
dtype
=
out_dtype
)(
node_inps
.
owner
.
inputs
[
0
])]
# figure out which axes were in the original sum
newaxis
=
list
(
tuple
(
node_inps
.
owner
.
op
.
axis
)
)
newaxis
=
list
(
node_inps
.
owner
.
op
.
axis
)
for
i
in
node
.
op
.
axis
:
new_i
=
i
for
ii
in
node_inps
.
owner
.
op
.
axis
:
...
...
pytensor/tensor/shape.py
浏览文件 @
1da2891c
...
...
@@ -810,7 +810,7 @@ def shape_padleft(t, n_ones=1):
"""
_t
=
at
.
as_tensor_variable
(
t
)
pattern
=
[
"x"
]
*
n_ones
+
[
i
for
i
in
range
(
_t
.
type
.
ndim
)]
pattern
=
[
"x"
]
*
n_ones
+
list
(
range
(
_t
.
type
.
ndim
))
return
_t
.
dimshuffle
(
pattern
)
...
...
@@ -826,7 +826,7 @@ def shape_padright(t, n_ones=1):
"""
_t
=
at
.
as_tensor_variable
(
t
)
pattern
=
[
i
for
i
in
range
(
_t
.
type
.
ndim
)]
+
[
"x"
]
*
n_ones
pattern
=
list
(
range
(
_t
.
type
.
ndim
))
+
[
"x"
]
*
n_ones
return
_t
.
dimshuffle
(
pattern
)
...
...
@@ -861,7 +861,7 @@ def shape_padaxis(t, axis):
if
axis
<
0
:
axis
+=
ndim
pattern
=
[
i
for
i
in
range
(
_t
.
type
.
ndim
)]
pattern
=
list
(
range
(
_t
.
type
.
ndim
))
pattern
.
insert
(
axis
,
"x"
)
return
_t
.
dimshuffle
(
pattern
)
...
...
pytensor/tensor/subtensor.py
浏览文件 @
1da2891c
...
...
@@ -2604,7 +2604,7 @@ class AdvancedSubtensor(Op):
ishapes
[
0
],
index_shapes
,
indices_are_shapes
=
True
)
assert
node
.
outputs
[
0
]
.
ndim
==
len
(
res_shape
)
return
[
[
s
for
s
in
res_shape
]
]
return
[
list
(
res_shape
)
]
def
perform
(
self
,
node
,
inputs
,
out_
):
(
out
,)
=
out_
...
...
setup.cfg
浏览文件 @
1da2891c
[flake8]
select = C,E,F,W
ignore = E203,E231,E501,E741,W503,W504,C901
ignore = E203,E231,E501,E741,W503,W504,C
408,C
901
per-file-ignores =
**/__init__.py:F401,E402,F403
pytensor/tensor/linalg.py:F401,F403
...
...
tests/graph/test_features.py
浏览文件 @
1da2891c
...
...
@@ -73,7 +73,7 @@ class TestNodeFinder:
assert
hasattr
(
g
,
"get_nodes"
)
for
type
,
num
in
((
add
,
3
),
(
sigmoid
,
3
),
(
dot
,
2
)):
if
len
(
[
t
for
t
in
g
.
get_nodes
(
type
)]
)
!=
num
:
if
len
(
list
(
g
.
get_nodes
(
type
))
)
!=
num
:
raise
Exception
(
"Expected:
%
i times
%
s"
%
(
num
,
type
))
new_e0
=
add
(
y
,
z
)
assert
e0
.
owner
in
g
.
get_nodes
(
dot
)
...
...
@@ -82,7 +82,7 @@ class TestNodeFinder:
assert
e0
.
owner
not
in
g
.
get_nodes
(
dot
)
assert
new_e0
.
owner
in
g
.
get_nodes
(
add
)
for
type
,
num
in
((
add
,
4
),
(
sigmoid
,
3
),
(
dot
,
1
)):
if
len
(
[
t
for
t
in
g
.
get_nodes
(
type
)]
)
!=
num
:
if
len
(
list
(
g
.
get_nodes
(
type
))
)
!=
num
:
raise
Exception
(
"Expected:
%
i times
%
s"
%
(
num
,
type
))
...
...
tests/graph/test_op.py
浏览文件 @
1da2891c
...
...
@@ -87,7 +87,7 @@ class TestOp:
r1
,
r2
=
MyType
(
1
)(),
MyType
(
2
)()
node
=
MyOp
.
make_node
(
r1
,
r2
)
# Are the inputs what I provided?
assert
[
x
for
x
in
node
.
inputs
]
==
[
r1
,
r2
]
assert
list
(
node
.
inputs
)
==
[
r1
,
r2
]
# Are the outputs what I expect?
assert
[
x
.
type
for
x
in
node
.
outputs
]
==
[
MyType
(
3
)]
assert
node
.
outputs
[
0
]
.
owner
is
node
and
node
.
outputs
[
0
]
.
index
==
0
...
...
tests/tensor/rewriting/test_elemwise.py
浏览文件 @
1da2891c
...
...
@@ -1123,7 +1123,7 @@ class TestFusion:
out
=
dot
(
x
,
y
)
+
x
+
y
+
z
f
=
function
([
x
,
y
,
z
],
out
,
mode
=
self
.
mode
)
topo
=
[
n
for
n
in
f
.
maker
.
fgraph
.
toposort
()]
topo
=
list
(
f
.
maker
.
fgraph
.
toposort
())
assert
len
(
topo
)
==
2
assert
topo
[
-
1
]
.
op
.
inplace_pattern
...
...
tests/tensor/rewriting/test_math.py
浏览文件 @
1da2891c
...
...
@@ -3994,9 +3994,9 @@ class TestSigmoidUtils:
exp_op
=
exp
assert
is_1pexp
(
1
+
exp_op
(
x
),
False
)
==
(
False
,
x
)
assert
is_1pexp
(
exp_op
(
x
)
+
1
,
False
)
==
(
False
,
x
)
for
neg_
,
exp_arg
in
map
(
lambda
x
:
is_1pexp
(
x
,
only_process_constants
=
False
),
[(
1
+
exp_op
(
-
x
)),
(
exp_op
(
-
x
)
+
1
)],
for
neg_
,
exp_arg
in
(
is_1pexp
(
x
,
only_process_constants
=
False
)
for
x
in
[(
1
+
exp_op
(
-
x
)),
(
exp_op
(
-
x
)
+
1
)]
):
assert
not
neg_
and
is_same_graph
(
exp_arg
,
-
x
)
assert
is_1pexp
(
1
-
exp_op
(
x
),
False
)
is
None
...
...
tests/tensor/rewriting/test_subtensor.py
浏览文件 @
1da2891c
...
...
@@ -2004,7 +2004,7 @@ def test_local_subtensor_SpecifyShape_lift(x, s, idx, x_val, s_val):
y_val_fn
=
function
(
[
x
]
+
list
(
s
),
y
,
on_unused_input
=
"ignore"
,
mode
=
no_rewrites_mode
)
y_val
=
y_val_fn
(
*
([
x_val
]
+
[
s_
for
s_
in
s_val
]
))
y_val
=
y_val_fn
(
*
([
x_val
]
+
list
(
s_val
)
))
# This optimization should appear in the canonicalizations
y_opt
=
rewrite_graph
(
y
,
clone
=
False
)
...
...
@@ -2017,7 +2017,7 @@ def test_local_subtensor_SpecifyShape_lift(x, s, idx, x_val, s_val):
assert
isinstance
(
y_opt
.
owner
.
op
,
SpecifyShape
)
y_opt_fn
=
function
([
x
]
+
list
(
s
),
y_opt
,
on_unused_input
=
"ignore"
)
y_opt_val
=
y_opt_fn
(
*
([
x_val
]
+
[
s_
for
s_
in
s_val
]
))
y_opt_val
=
y_opt_fn
(
*
([
x_val
]
+
list
(
s_val
)
))
assert
np
.
allclose
(
y_val
,
y_opt_val
)
...
...
tests/tensor/test_blas.py
浏览文件 @
1da2891c
...
...
@@ -2589,10 +2589,10 @@ TestBatchedDot = makeTester(
op
=
batched_dot
,
expected
=
(
lambda
xs
,
ys
:
np
.
asarray
(
list
(
[
x
*
y
if
x
.
ndim
==
0
or
y
.
ndim
==
0
else
np
.
dot
(
x
,
y
)
for
x
,
y
in
zip
(
xs
,
ys
)
)
,
]
,
dtype
=
aes
.
upcast
(
xs
.
dtype
,
ys
.
dtype
),
)
),
...
...
@@ -2694,7 +2694,7 @@ def test_batched_dot_not_contiguous():
assert
x
.
strides
[
0
]
==
direction
*
np
.
dtype
(
config
.
floatX
)
.
itemsize
assert
not
(
x
.
flags
[
"C_CONTIGUOUS"
]
or
x
.
flags
[
"F_CONTIGUOUS"
])
result
=
f
(
x
,
w
)
ref_result
=
np
.
asarray
(
list
(
np
.
dot
(
u
,
v
)
for
u
,
v
in
zip
(
x
,
w
))
)
ref_result
=
np
.
asarray
(
[
np
.
dot
(
u
,
v
)
for
u
,
v
in
zip
(
x
,
w
)]
)
utt
.
assert_allclose
(
ref_result
,
result
)
for
inverted
in
(
0
,
1
):
...
...
tests/tensor/test_complex.py
浏览文件 @
1da2891c
...
...
@@ -15,9 +15,7 @@ class TestRealImag:
x
=
zvector
()
rng
=
np
.
random
.
default_rng
(
23
)
xval
=
np
.
asarray
(
list
(
complex
(
rng
.
standard_normal
(),
rng
.
standard_normal
())
for
i
in
range
(
10
)
)
[
complex
(
rng
.
standard_normal
(),
rng
.
standard_normal
())
for
i
in
range
(
10
)]
)
assert
np
.
all
(
xval
.
real
==
pytensor
.
function
([
x
],
real
(
x
))(
xval
))
assert
np
.
all
(
xval
.
imag
==
pytensor
.
function
([
x
],
imag
(
x
))(
xval
))
...
...
tests/tensor/test_elemwise.py
浏览文件 @
1da2891c
...
...
@@ -490,50 +490,50 @@ class TestCAReduce(unittest_tools.InferShapeTester):
assert
len
(
axis2
)
==
len
(
tosum
)
tosum
=
tuple
(
axis2
)
if
tensor_op
==
at_all
:
for
axis
in
reversed
(
sorted
(
tosum
)
):
for
axis
in
sorted
(
tosum
,
reverse
=
True
):
zv
=
np
.
all
(
zv
,
axis
)
if
len
(
tosum
)
==
0
:
zv
=
zv
!=
0
elif
tensor_op
==
at_any
:
for
axis
in
reversed
(
sorted
(
tosum
)
):
for
axis
in
sorted
(
tosum
,
reverse
=
True
):
zv
=
np
.
any
(
zv
,
axis
)
if
len
(
tosum
)
==
0
:
zv
=
zv
!=
0
elif
scalar_op
==
aes
.
add
:
for
axis
in
reversed
(
sorted
(
tosum
)
):
for
axis
in
sorted
(
tosum
,
reverse
=
True
):
zv
=
np
.
add
.
reduce
(
zv
,
axis
)
if
dtype
==
"bool"
:
# np.add of a bool upcast, while CAReduce don't
zv
=
zv
.
astype
(
dtype
)
elif
scalar_op
==
aes
.
mul
:
for
axis
in
reversed
(
sorted
(
tosum
)
):
for
axis
in
sorted
(
tosum
,
reverse
=
True
):
zv
=
np
.
multiply
.
reduce
(
zv
,
axis
)
elif
scalar_op
==
aes
.
scalar_maximum
:
# There is no identity value for the maximum function
# So we can't support shape of dimensions 0.
if
np
.
prod
(
zv
.
shape
)
==
0
:
continue
for
axis
in
reversed
(
sorted
(
tosum
)
):
for
axis
in
sorted
(
tosum
,
reverse
=
True
):
zv
=
np
.
maximum
.
reduce
(
zv
,
axis
)
elif
scalar_op
==
aes
.
scalar_minimum
:
# There is no identity value for the minimum function
# So we can't support shape of dimensions 0.
if
np
.
prod
(
zv
.
shape
)
==
0
:
continue
for
axis
in
reversed
(
sorted
(
tosum
)
):
for
axis
in
sorted
(
tosum
,
reverse
=
True
):
zv
=
np
.
minimum
.
reduce
(
zv
,
axis
)
elif
scalar_op
==
aes
.
or_
:
for
axis
in
reversed
(
sorted
(
tosum
)
):
for
axis
in
sorted
(
tosum
,
reverse
=
True
):
zv
=
np
.
bitwise_or
.
reduce
(
zv
,
axis
)
elif
scalar_op
==
aes
.
and_
:
for
axis
in
reversed
(
sorted
(
tosum
)
):
for
axis
in
sorted
(
tosum
,
reverse
=
True
):
zv
=
reduce_bitwise_and
(
zv
,
axis
,
dtype
=
dtype
)
elif
scalar_op
==
aes
.
xor
:
# There is no identity value for the xor function
# So we can't support shape of dimensions 0.
if
np
.
prod
(
zv
.
shape
)
==
0
:
continue
for
axis
in
reversed
(
sorted
(
tosum
)
):
for
axis
in
sorted
(
tosum
,
reverse
=
True
):
zv
=
np
.
bitwise_xor
.
reduce
(
zv
,
axis
)
else
:
raise
NotImplementedError
(
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论