提交 1da2891c authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Add flake8-comprehensions plugin

上级 bc878138
...@@ -33,6 +33,8 @@ repos: ...@@ -33,6 +33,8 @@ repos:
rev: 6.0.0 rev: 6.0.0
hooks: hooks:
- id: flake8 - id: flake8
additional_dependencies:
- flake8-comprehensions
- repo: https://github.com/pycqa/isort - repo: https://github.com/pycqa/isort
rev: 5.12.0 rev: 5.12.0
hooks: hooks:
......
...@@ -969,9 +969,7 @@ def inline_ofg_expansion(fgraph, node): ...@@ -969,9 +969,7 @@ def inline_ofg_expansion(fgraph, node):
return False return False
if not op.is_inline: if not op.is_inline:
return False return False
return clone_replace( return clone_replace(op.inner_outputs, dict(zip(op.inner_inputs, node.inputs)))
op.inner_outputs, {u: v for u, v in zip(op.inner_inputs, node.inputs)}
)
# We want to run this before the first merge optimizer # We want to run this before the first merge optimizer
......
...@@ -504,7 +504,7 @@ def grad( ...@@ -504,7 +504,7 @@ def grad(
if not isinstance(wrt, Sequence): if not isinstance(wrt, Sequence):
_wrt: List[Variable] = [wrt] _wrt: List[Variable] = [wrt]
else: else:
_wrt = [x for x in wrt] _wrt = list(wrt)
outputs = [] outputs = []
if cost is not None: if cost is not None:
...@@ -791,8 +791,8 @@ def subgraph_grad(wrt, end, start=None, cost=None, details=False): ...@@ -791,8 +791,8 @@ def subgraph_grad(wrt, end, start=None, cost=None, details=False):
pgrads = dict(zip(params, grads)) pgrads = dict(zip(params, grads))
# separate wrt from end grads: # separate wrt from end grads:
wrt_grads = list(pgrads[k] for k in wrt) wrt_grads = [pgrads[k] for k in wrt]
end_grads = list(pgrads[k] for k in end) end_grads = [pgrads[k] for k in end]
if details: if details:
return wrt_grads, end_grads, start_grads, cost_grads return wrt_grads, end_grads, start_grads, cost_grads
......
...@@ -1629,7 +1629,7 @@ def as_string( ...@@ -1629,7 +1629,7 @@ def as_string(
multi.add(op2) multi.add(op2)
else: else:
seen.add(input.owner) seen.add(input.owner)
multi_list = [x for x in multi] multi_list = list(multi)
done: Set = set() done: Set = set()
def multi_index(x): def multi_index(x):
......
...@@ -142,7 +142,7 @@ def graph_replace( ...@@ -142,7 +142,7 @@ def graph_replace(
raise ValueError(f"{key} is not a part of graph") raise ValueError(f"{key} is not a part of graph")
sorted_replacements = sorted( sorted_replacements = sorted(
tuple(fg_replace.items()), fg_replace.items(),
# sort based on the fg toposort, if a variable has no owner, it goes first # sort based on the fg toposort, if a variable has no owner, it goes first
key=partial(toposort_key, fg, toposort), key=partial(toposort_key, fg, toposort),
reverse=True, reverse=True,
......
...@@ -2575,8 +2575,8 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter): ...@@ -2575,8 +2575,8 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter):
for i in range(len(loop_timing)): for i in range(len(loop_timing)):
loop_times = "" loop_times = ""
if loop_process_count[i]: if loop_process_count[i]:
d = list( d = sorted(
reversed(sorted(loop_process_count[i].items(), key=lambda a: a[1])) loop_process_count[i].items(), key=lambda a: a[1], reverse=True
) )
loop_times = " ".join([str((str(k), v)) for k, v in d[:5]]) loop_times = " ".join([str((str(k), v)) for k, v in d[:5]])
if len(d) > 5: if len(d) > 5:
......
...@@ -633,11 +633,11 @@ class CLinker(Linker): ...@@ -633,11 +633,11 @@ class CLinker(Linker):
# The orphans field is listified to ensure a consistent order. # The orphans field is listified to ensure a consistent order.
# list(fgraph.orphans.difference(self.outputs)) # list(fgraph.orphans.difference(self.outputs))
self.orphans = list( self.orphans = [
r r
for r in self.variables for r in self.variables
if isinstance(r, AtomicVariable) and r not in self.inputs if isinstance(r, AtomicVariable) and r not in self.inputs
) ]
# C type constants (pytensor.scalar.ScalarType). They don't request an object # C type constants (pytensor.scalar.ScalarType). They don't request an object
self.consts = [] self.consts = []
# Move c type from orphans (pytensor.scalar.ScalarType) to self.consts # Move c type from orphans (pytensor.scalar.ScalarType) to self.consts
......
...@@ -810,7 +810,7 @@ class ParamsType(CType): ...@@ -810,7 +810,7 @@ class ParamsType(CType):
struct_extract_method=struct_extract_method, struct_extract_method=struct_extract_method,
) )
return list(sorted(list(c_support_code_set))) + [final_struct_code] return sorted(c_support_code_set) + [final_struct_code]
def c_code_cache_version(self): def c_code_cache_version(self):
return ((3,), tuple(t.c_code_cache_version() for t in self.types)) return ((3,), tuple(t.c_code_cache_version() for t in self.types))
......
...@@ -41,7 +41,7 @@ def jax_funcify_CAReduce(op, **kwargs): ...@@ -41,7 +41,7 @@ def jax_funcify_CAReduce(op, **kwargs):
elif scalar_op_name: elif scalar_op_name:
scalar_fn_name = scalar_op_name scalar_fn_name = scalar_op_name
to_reduce = reversed(sorted(axis)) to_reduce = sorted(axis, reverse=True)
if to_reduce: if to_reduce:
# In this case, we need to use the `jax.lax` function (if there # In this case, we need to use the `jax.lax` function (if there
......
...@@ -361,7 +361,7 @@ def create_multiaxis_reducer( ...@@ -361,7 +361,7 @@ def create_multiaxis_reducer(
careduce_fn_name = f"careduce_{scalar_op}" careduce_fn_name = f"careduce_{scalar_op}"
global_env = {} global_env = {}
to_reduce = reversed(sorted(axes)) to_reduce = sorted(axes, reverse=True)
careduce_lines_src = [] careduce_lines_src = []
var_name = input_name var_name = input_name
......
...@@ -796,7 +796,7 @@ class Print(Op): ...@@ -796,7 +796,7 @@ class Print(Op):
return output_gradients return output_gradients
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
return [x for x in eval_points] return list(eval_points)
def __setstate__(self, dct): def __setstate__(self, dct):
dct.setdefault("global_fn", _print_fn) dct.setdefault("global_fn", _print_fn)
......
...@@ -492,7 +492,7 @@ def scan( ...@@ -492,7 +492,7 @@ def scan(
# wrap sequences in a dictionary if they are not already dictionaries # wrap sequences in a dictionary if they are not already dictionaries
for i in range(n_seqs): for i in range(n_seqs):
if not isinstance(seqs[i], dict): if not isinstance(seqs[i], dict):
seqs[i] = dict([("input", seqs[i]), ("taps", [0])]) seqs[i] = {"input": seqs[i], "taps": [0]}
elif seqs[i].get("taps", None) is not None: elif seqs[i].get("taps", None) is not None:
seqs[i]["taps"] = wrap_into_list(seqs[i]["taps"]) seqs[i]["taps"] = wrap_into_list(seqs[i]["taps"])
elif seqs[i].get("taps", None) is None: elif seqs[i].get("taps", None) is None:
...@@ -504,7 +504,7 @@ def scan( ...@@ -504,7 +504,7 @@ def scan(
if outs_info[i] is not None: if outs_info[i] is not None:
if not isinstance(outs_info[i], dict): if not isinstance(outs_info[i], dict):
# by default any output has a tap value of -1 # by default any output has a tap value of -1
outs_info[i] = dict([("initial", outs_info[i]), ("taps", [-1])]) outs_info[i] = {"initial": outs_info[i], "taps": [-1]}
elif ( elif (
outs_info[i].get("initial", None) is None outs_info[i].get("initial", None) is None
and outs_info[i].get("taps", None) is not None and outs_info[i].get("taps", None) is not None
......
...@@ -1718,12 +1718,9 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1718,12 +1718,9 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
arg.shape[0] arg.shape[0]
for arg in inputs[self.seqs_arg_offset : self.shared_arg_offset] for arg in inputs[self.seqs_arg_offset : self.shared_arg_offset]
] ]
store_steps += [ store_steps += list(
arg inputs[self.nit_sot_arg_offset : self.nit_sot_arg_offset + info.n_nit_sot]
for arg in inputs[ )
self.nit_sot_arg_offset : self.nit_sot_arg_offset + info.n_nit_sot
]
]
# 2.1 Create storage space for outputs # 2.1 Create storage space for outputs
for idx in range(self.n_outs): for idx in range(self.n_outs):
...@@ -2270,7 +2267,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2270,7 +2267,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
) )
offset = 1 + info.n_seqs offset = 1 + info.n_seqs
scan_outs = [x for x in input_shapes[offset : offset + n_outs]] scan_outs = list(input_shapes[offset : offset + n_outs])
offset += n_outs offset += n_outs
outs_shape_n = info.n_mit_mot_outs + info.n_mit_sot + info.n_sit_sot outs_shape_n = info.n_mit_mot_outs + info.n_mit_sot + info.n_sit_sot
for x in range(info.n_nit_sot): for x in range(info.n_nit_sot):
...@@ -2301,7 +2298,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2301,7 +2298,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
shp.append(v_shp_i[0]) shp.append(v_shp_i[0])
scan_outs.append(tuple(shp)) scan_outs.append(tuple(shp))
scan_outs += [x for x in input_shapes[offset : offset + info.n_shared_outs]] scan_outs += list(input_shapes[offset : offset + info.n_shared_outs])
# if we are dealing with a repeat-until, then we do not know the # if we are dealing with a repeat-until, then we do not know the
# leading dimension so we replace it for every entry with Shape_i # leading dimension so we replace it for every entry with Shape_i
if info.as_while: if info.as_while:
......
...@@ -388,7 +388,7 @@ def push_out_non_seq_scan(fgraph, node): ...@@ -388,7 +388,7 @@ def push_out_non_seq_scan(fgraph, node):
if out in local_fgraph_outs_set: if out in local_fgraph_outs_set:
x = node.outputs[local_fgraph_outs_map[out]] x = node.outputs[local_fgraph_outs_map[out]]
y = replace_with_out[idx] y = replace_with_out[idx]
y_shape = [shp for shp in y.shape] y_shape = list(y.shape)
replace_with[x] = at.alloc(y, node.inputs[0], *y_shape) replace_with[x] = at.alloc(y, node.inputs[0], *y_shape)
# We need to add one extra dimension to the outputs # We need to add one extra dimension to the outputs
......
...@@ -283,7 +283,7 @@ class RandomVariable(Op): ...@@ -283,7 +283,7 @@ class RandomVariable(Op):
shape = self._infer_shape(size, dist_params, param_shapes=param_shapes) shape = self._infer_shape(size, dist_params, param_shapes=param_shapes)
return [None, [s for s in shape]] return [None, list(shape)]
def __call__(self, *args, size=None, name=None, rng=None, dtype=None, **kwargs): def __call__(self, *args, size=None, name=None, rng=None, dtype=None, **kwargs):
res = super().__call__(rng, size, dtype, *args, **kwargs) res = super().__call__(rng, size, dtype, *args, **kwargs)
......
...@@ -1555,11 +1555,11 @@ def local_sum_prod_div_dimshuffle(fgraph, node): ...@@ -1555,11 +1555,11 @@ def local_sum_prod_div_dimshuffle(fgraph, node):
) )
if len(compatible_dims) > 0: if len(compatible_dims) > 0:
optimized_dimshuffle_order = list( optimized_dimshuffle_order = [
ax ax
for i, ax in enumerate(dimshuffle_order) for i, ax in enumerate(dimshuffle_order)
if (i not in axis) or (ax != "x") if (i not in axis) or (ax != "x")
) ]
# Removing leading 'x' (since it will be done automatically) # Removing leading 'x' (since it will be done automatically)
while ( while (
...@@ -1644,7 +1644,7 @@ def local_op_of_op(fgraph, node): ...@@ -1644,7 +1644,7 @@ def local_op_of_op(fgraph, node):
return [op_type(None, dtype=out_dtype)(node_inps.owner.inputs[0])] return [op_type(None, dtype=out_dtype)(node_inps.owner.inputs[0])]
# figure out which axes were in the original sum # figure out which axes were in the original sum
newaxis = list(tuple(node_inps.owner.op.axis)) newaxis = list(node_inps.owner.op.axis)
for i in node.op.axis: for i in node.op.axis:
new_i = i new_i = i
for ii in node_inps.owner.op.axis: for ii in node_inps.owner.op.axis:
......
...@@ -810,7 +810,7 @@ def shape_padleft(t, n_ones=1): ...@@ -810,7 +810,7 @@ def shape_padleft(t, n_ones=1):
""" """
_t = at.as_tensor_variable(t) _t = at.as_tensor_variable(t)
pattern = ["x"] * n_ones + [i for i in range(_t.type.ndim)] pattern = ["x"] * n_ones + list(range(_t.type.ndim))
return _t.dimshuffle(pattern) return _t.dimshuffle(pattern)
...@@ -826,7 +826,7 @@ def shape_padright(t, n_ones=1): ...@@ -826,7 +826,7 @@ def shape_padright(t, n_ones=1):
""" """
_t = at.as_tensor_variable(t) _t = at.as_tensor_variable(t)
pattern = [i for i in range(_t.type.ndim)] + ["x"] * n_ones pattern = list(range(_t.type.ndim)) + ["x"] * n_ones
return _t.dimshuffle(pattern) return _t.dimshuffle(pattern)
...@@ -861,7 +861,7 @@ def shape_padaxis(t, axis): ...@@ -861,7 +861,7 @@ def shape_padaxis(t, axis):
if axis < 0: if axis < 0:
axis += ndim axis += ndim
pattern = [i for i in range(_t.type.ndim)] pattern = list(range(_t.type.ndim))
pattern.insert(axis, "x") pattern.insert(axis, "x")
return _t.dimshuffle(pattern) return _t.dimshuffle(pattern)
......
...@@ -2604,7 +2604,7 @@ class AdvancedSubtensor(Op): ...@@ -2604,7 +2604,7 @@ class AdvancedSubtensor(Op):
ishapes[0], index_shapes, indices_are_shapes=True ishapes[0], index_shapes, indices_are_shapes=True
) )
assert node.outputs[0].ndim == len(res_shape) assert node.outputs[0].ndim == len(res_shape)
return [[s for s in res_shape]] return [list(res_shape)]
def perform(self, node, inputs, out_): def perform(self, node, inputs, out_):
(out,) = out_ (out,) = out_
......
[flake8] [flake8]
select = C,E,F,W select = C,E,F,W
ignore = E203,E231,E501,E741,W503,W504,C901 ignore = E203,E231,E501,E741,W503,W504,C408,C901
per-file-ignores = per-file-ignores =
**/__init__.py:F401,E402,F403 **/__init__.py:F401,E402,F403
pytensor/tensor/linalg.py:F401,F403 pytensor/tensor/linalg.py:F401,F403
......
...@@ -73,7 +73,7 @@ class TestNodeFinder: ...@@ -73,7 +73,7 @@ class TestNodeFinder:
assert hasattr(g, "get_nodes") assert hasattr(g, "get_nodes")
for type, num in ((add, 3), (sigmoid, 3), (dot, 2)): for type, num in ((add, 3), (sigmoid, 3), (dot, 2)):
if len([t for t in g.get_nodes(type)]) != num: if len(list(g.get_nodes(type))) != num:
raise Exception("Expected: %i times %s" % (num, type)) raise Exception("Expected: %i times %s" % (num, type))
new_e0 = add(y, z) new_e0 = add(y, z)
assert e0.owner in g.get_nodes(dot) assert e0.owner in g.get_nodes(dot)
...@@ -82,7 +82,7 @@ class TestNodeFinder: ...@@ -82,7 +82,7 @@ class TestNodeFinder:
assert e0.owner not in g.get_nodes(dot) assert e0.owner not in g.get_nodes(dot)
assert new_e0.owner in g.get_nodes(add) assert new_e0.owner in g.get_nodes(add)
for type, num in ((add, 4), (sigmoid, 3), (dot, 1)): for type, num in ((add, 4), (sigmoid, 3), (dot, 1)):
if len([t for t in g.get_nodes(type)]) != num: if len(list(g.get_nodes(type))) != num:
raise Exception("Expected: %i times %s" % (num, type)) raise Exception("Expected: %i times %s" % (num, type))
......
...@@ -87,7 +87,7 @@ class TestOp: ...@@ -87,7 +87,7 @@ class TestOp:
r1, r2 = MyType(1)(), MyType(2)() r1, r2 = MyType(1)(), MyType(2)()
node = MyOp.make_node(r1, r2) node = MyOp.make_node(r1, r2)
# Are the inputs what I provided? # Are the inputs what I provided?
assert [x for x in node.inputs] == [r1, r2] assert list(node.inputs) == [r1, r2]
# Are the outputs what I expect? # Are the outputs what I expect?
assert [x.type for x in node.outputs] == [MyType(3)] assert [x.type for x in node.outputs] == [MyType(3)]
assert node.outputs[0].owner is node and node.outputs[0].index == 0 assert node.outputs[0].owner is node and node.outputs[0].index == 0
......
...@@ -1123,7 +1123,7 @@ class TestFusion: ...@@ -1123,7 +1123,7 @@ class TestFusion:
out = dot(x, y) + x + y + z out = dot(x, y) + x + y + z
f = function([x, y, z], out, mode=self.mode) f = function([x, y, z], out, mode=self.mode)
topo = [n for n in f.maker.fgraph.toposort()] topo = list(f.maker.fgraph.toposort())
assert len(topo) == 2 assert len(topo) == 2
assert topo[-1].op.inplace_pattern assert topo[-1].op.inplace_pattern
......
...@@ -3994,9 +3994,9 @@ class TestSigmoidUtils: ...@@ -3994,9 +3994,9 @@ class TestSigmoidUtils:
exp_op = exp exp_op = exp
assert is_1pexp(1 + exp_op(x), False) == (False, x) assert is_1pexp(1 + exp_op(x), False) == (False, x)
assert is_1pexp(exp_op(x) + 1, False) == (False, x) assert is_1pexp(exp_op(x) + 1, False) == (False, x)
for neg_, exp_arg in map( for neg_, exp_arg in (
lambda x: is_1pexp(x, only_process_constants=False), is_1pexp(x, only_process_constants=False)
[(1 + exp_op(-x)), (exp_op(-x) + 1)], for x in [(1 + exp_op(-x)), (exp_op(-x) + 1)]
): ):
assert not neg_ and is_same_graph(exp_arg, -x) assert not neg_ and is_same_graph(exp_arg, -x)
assert is_1pexp(1 - exp_op(x), False) is None assert is_1pexp(1 - exp_op(x), False) is None
......
...@@ -2004,7 +2004,7 @@ def test_local_subtensor_SpecifyShape_lift(x, s, idx, x_val, s_val): ...@@ -2004,7 +2004,7 @@ def test_local_subtensor_SpecifyShape_lift(x, s, idx, x_val, s_val):
y_val_fn = function( y_val_fn = function(
[x] + list(s), y, on_unused_input="ignore", mode=no_rewrites_mode [x] + list(s), y, on_unused_input="ignore", mode=no_rewrites_mode
) )
y_val = y_val_fn(*([x_val] + [s_ for s_ in s_val])) y_val = y_val_fn(*([x_val] + list(s_val)))
# This optimization should appear in the canonicalizations # This optimization should appear in the canonicalizations
y_opt = rewrite_graph(y, clone=False) y_opt = rewrite_graph(y, clone=False)
...@@ -2017,7 +2017,7 @@ def test_local_subtensor_SpecifyShape_lift(x, s, idx, x_val, s_val): ...@@ -2017,7 +2017,7 @@ def test_local_subtensor_SpecifyShape_lift(x, s, idx, x_val, s_val):
assert isinstance(y_opt.owner.op, SpecifyShape) assert isinstance(y_opt.owner.op, SpecifyShape)
y_opt_fn = function([x] + list(s), y_opt, on_unused_input="ignore") y_opt_fn = function([x] + list(s), y_opt, on_unused_input="ignore")
y_opt_val = y_opt_fn(*([x_val] + [s_ for s_ in s_val])) y_opt_val = y_opt_fn(*([x_val] + list(s_val)))
assert np.allclose(y_val, y_opt_val) assert np.allclose(y_val, y_opt_val)
......
...@@ -2589,10 +2589,10 @@ TestBatchedDot = makeTester( ...@@ -2589,10 +2589,10 @@ TestBatchedDot = makeTester(
op=batched_dot, op=batched_dot,
expected=( expected=(
lambda xs, ys: np.asarray( lambda xs, ys: np.asarray(
list( [
x * y if x.ndim == 0 or y.ndim == 0 else np.dot(x, y) x * y if x.ndim == 0 or y.ndim == 0 else np.dot(x, y)
for x, y in zip(xs, ys) for x, y in zip(xs, ys)
), ],
dtype=aes.upcast(xs.dtype, ys.dtype), dtype=aes.upcast(xs.dtype, ys.dtype),
) )
), ),
...@@ -2694,7 +2694,7 @@ def test_batched_dot_not_contiguous(): ...@@ -2694,7 +2694,7 @@ def test_batched_dot_not_contiguous():
assert x.strides[0] == direction * np.dtype(config.floatX).itemsize assert x.strides[0] == direction * np.dtype(config.floatX).itemsize
assert not (x.flags["C_CONTIGUOUS"] or x.flags["F_CONTIGUOUS"]) assert not (x.flags["C_CONTIGUOUS"] or x.flags["F_CONTIGUOUS"])
result = f(x, w) result = f(x, w)
ref_result = np.asarray(list(np.dot(u, v) for u, v in zip(x, w))) ref_result = np.asarray([np.dot(u, v) for u, v in zip(x, w)])
utt.assert_allclose(ref_result, result) utt.assert_allclose(ref_result, result)
for inverted in (0, 1): for inverted in (0, 1):
......
...@@ -15,9 +15,7 @@ class TestRealImag: ...@@ -15,9 +15,7 @@ class TestRealImag:
x = zvector() x = zvector()
rng = np.random.default_rng(23) rng = np.random.default_rng(23)
xval = np.asarray( xval = np.asarray(
list( [complex(rng.standard_normal(), rng.standard_normal()) for i in range(10)]
complex(rng.standard_normal(), rng.standard_normal()) for i in range(10)
)
) )
assert np.all(xval.real == pytensor.function([x], real(x))(xval)) assert np.all(xval.real == pytensor.function([x], real(x))(xval))
assert np.all(xval.imag == pytensor.function([x], imag(x))(xval)) assert np.all(xval.imag == pytensor.function([x], imag(x))(xval))
......
...@@ -490,50 +490,50 @@ class TestCAReduce(unittest_tools.InferShapeTester): ...@@ -490,50 +490,50 @@ class TestCAReduce(unittest_tools.InferShapeTester):
assert len(axis2) == len(tosum) assert len(axis2) == len(tosum)
tosum = tuple(axis2) tosum = tuple(axis2)
if tensor_op == at_all: if tensor_op == at_all:
for axis in reversed(sorted(tosum)): for axis in sorted(tosum, reverse=True):
zv = np.all(zv, axis) zv = np.all(zv, axis)
if len(tosum) == 0: if len(tosum) == 0:
zv = zv != 0 zv = zv != 0
elif tensor_op == at_any: elif tensor_op == at_any:
for axis in reversed(sorted(tosum)): for axis in sorted(tosum, reverse=True):
zv = np.any(zv, axis) zv = np.any(zv, axis)
if len(tosum) == 0: if len(tosum) == 0:
zv = zv != 0 zv = zv != 0
elif scalar_op == aes.add: elif scalar_op == aes.add:
for axis in reversed(sorted(tosum)): for axis in sorted(tosum, reverse=True):
zv = np.add.reduce(zv, axis) zv = np.add.reduce(zv, axis)
if dtype == "bool": if dtype == "bool":
# np.add of a bool upcast, while CAReduce don't # np.add of a bool upcast, while CAReduce don't
zv = zv.astype(dtype) zv = zv.astype(dtype)
elif scalar_op == aes.mul: elif scalar_op == aes.mul:
for axis in reversed(sorted(tosum)): for axis in sorted(tosum, reverse=True):
zv = np.multiply.reduce(zv, axis) zv = np.multiply.reduce(zv, axis)
elif scalar_op == aes.scalar_maximum: elif scalar_op == aes.scalar_maximum:
# There is no identity value for the maximum function # There is no identity value for the maximum function
# So we can't support shape of dimensions 0. # So we can't support shape of dimensions 0.
if np.prod(zv.shape) == 0: if np.prod(zv.shape) == 0:
continue continue
for axis in reversed(sorted(tosum)): for axis in sorted(tosum, reverse=True):
zv = np.maximum.reduce(zv, axis) zv = np.maximum.reduce(zv, axis)
elif scalar_op == aes.scalar_minimum: elif scalar_op == aes.scalar_minimum:
# There is no identity value for the minimum function # There is no identity value for the minimum function
# So we can't support shape of dimensions 0. # So we can't support shape of dimensions 0.
if np.prod(zv.shape) == 0: if np.prod(zv.shape) == 0:
continue continue
for axis in reversed(sorted(tosum)): for axis in sorted(tosum, reverse=True):
zv = np.minimum.reduce(zv, axis) zv = np.minimum.reduce(zv, axis)
elif scalar_op == aes.or_: elif scalar_op == aes.or_:
for axis in reversed(sorted(tosum)): for axis in sorted(tosum, reverse=True):
zv = np.bitwise_or.reduce(zv, axis) zv = np.bitwise_or.reduce(zv, axis)
elif scalar_op == aes.and_: elif scalar_op == aes.and_:
for axis in reversed(sorted(tosum)): for axis in sorted(tosum, reverse=True):
zv = reduce_bitwise_and(zv, axis, dtype=dtype) zv = reduce_bitwise_and(zv, axis, dtype=dtype)
elif scalar_op == aes.xor: elif scalar_op == aes.xor:
# There is no identity value for the xor function # There is no identity value for the xor function
# So we can't support shape of dimensions 0. # So we can't support shape of dimensions 0.
if np.prod(zv.shape) == 0: if np.prod(zv.shape) == 0:
continue continue
for axis in reversed(sorted(tosum)): for axis in sorted(tosum, reverse=True):
zv = np.bitwise_xor.reduce(zv, axis) zv = np.bitwise_xor.reduce(zv, axis)
else: else:
raise NotImplementedError( raise NotImplementedError(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论