提交 6f685799 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Remove unnecessary dicts from keyword specifications

上级 fec0fe85
...@@ -149,7 +149,7 @@ class IfElse(_NoPythonOp): ...@@ -149,7 +149,7 @@ class IfElse(_NoPythonOp):
new_outs = new_ifelse( new_outs = new_ifelse(
node.inputs[0], node.inputs[0],
*(new_ts_inputs + new_fs_inputs), *(new_ts_inputs + new_fs_inputs),
**dict(return_list=True), return_list=True,
) )
else: else:
new_outs = [] new_outs = []
...@@ -203,7 +203,7 @@ class IfElse(_NoPythonOp): ...@@ -203,7 +203,7 @@ class IfElse(_NoPythonOp):
return Apply(self, [c] + list(args), [t.type() for t in aes]) return Apply(self, [c] + list(args), [t.type() for t in aes])
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
return self(inputs[0], *eval_points[1:], **dict(return_list=True)) return self(inputs[0], *eval_points[1:], return_list=True)
def grad(self, ins, grads): def grad(self, ins, grads):
aes = ins[1:][: self.n_outs] aes = ins[1:][: self.n_outs]
...@@ -244,8 +244,8 @@ class IfElse(_NoPythonOp): ...@@ -244,8 +244,8 @@ class IfElse(_NoPythonOp):
condition_grad = condition.zeros_like().astype(config.floatX) condition_grad = condition.zeros_like().astype(config.floatX)
return ( return (
[condition_grad] [condition_grad]
+ if_true_op(*if_true, **dict(return_list=True)) + if_true_op(*if_true, return_list=True)
+ if_false_op(*if_false, **dict(return_list=True)) + if_false_op(*if_false, return_list=True)
) )
def make_thunk(self, node, storage_map, compute_map, no_recycling, impl=None): def make_thunk(self, node, storage_map, compute_map, no_recycling, impl=None):
...@@ -407,7 +407,7 @@ def ifelse(condition, then_branch, else_branch, name=None): ...@@ -407,7 +407,7 @@ def ifelse(condition, then_branch, else_branch, name=None):
new_ifelse = IfElse(n_outs=len(then_branch), as_view=False, gpu=False, name=name) new_ifelse = IfElse(n_outs=len(then_branch), as_view=False, gpu=False, name=name)
ins = [condition] + list(new_then_branch) + list(new_else_branch) ins = [condition] + list(new_then_branch) + list(new_else_branch)
rval = new_ifelse(*ins, **dict(return_list=True)) rval = new_ifelse(*ins, return_list=True)
if rval_type is None: if rval_type is None:
return rval[0] return rval[0]
...@@ -432,7 +432,7 @@ def cond_make_inplace(fgraph, node): ...@@ -432,7 +432,7 @@ def cond_make_inplace(fgraph, node):
) )
): ):
return IfElse(n_outs=op.n_outs, as_view=True, gpu=op.gpu, name=op.name)( return IfElse(n_outs=op.n_outs, as_view=True, gpu=op.gpu, name=op.name)(
*node.inputs, **dict(return_list=True) *node.inputs, return_list=True
) )
return False return False
...@@ -533,8 +533,8 @@ def ifelse_lift_single_if_through_acceptable_ops(fgraph, main_node): ...@@ -533,8 +533,8 @@ def ifelse_lift_single_if_through_acceptable_ops(fgraph, main_node):
else: else:
true_ins.append(x) true_ins.append(x)
false_ins.append(x) false_ins.append(x)
true_eval = mop(*true_ins, **dict(return_list=True)) true_eval = mop(*true_ins, return_list=True)
false_eval = mop(*false_ins, **dict(return_list=True)) false_eval = mop(*false_ins, return_list=True)
# true_eval = clone_replace(outs, replace = dict(zip(node.outputs, aes))) # true_eval = clone_replace(outs, replace = dict(zip(node.outputs, aes)))
# false_eval = clone_replace(outs, replace = dict(zip(node.outputs, fs))) # false_eval = clone_replace(outs, replace = dict(zip(node.outputs, fs)))
...@@ -566,7 +566,7 @@ def cond_merge_ifs_true(fgraph, node): ...@@ -566,7 +566,7 @@ def cond_merge_ifs_true(fgraph, node):
old_ins = list(node.inputs) old_ins = list(node.inputs)
for pos, var in replace.items(): for pos, var in replace.items():
old_ins[pos] = var old_ins[pos] = var
return op(*old_ins, **dict(return_list=True)) return op(*old_ins, return_list=True)
@local_optimizer([IfElse]) @local_optimizer([IfElse])
...@@ -593,7 +593,7 @@ def cond_merge_ifs_false(fgraph, node): ...@@ -593,7 +593,7 @@ def cond_merge_ifs_false(fgraph, node):
old_ins = list(node.inputs) old_ins = list(node.inputs)
for pos, var in replace.items(): for pos, var in replace.items():
old_ins[pos] = var old_ins[pos] = var
return op(*old_ins, **dict(return_list=True)) return op(*old_ins, return_list=True)
class CondMerge(GlobalOptimizer): class CondMerge(GlobalOptimizer):
...@@ -635,7 +635,7 @@ class CondMerge(GlobalOptimizer): ...@@ -635,7 +635,7 @@ class CondMerge(GlobalOptimizer):
name=mn_name + "&" + pl_name, name=mn_name + "&" + pl_name,
) )
print("here") print("here")
new_outs = new_ifelse(*new_ins, **dict(return_list=True)) new_outs = new_ifelse(*new_ins, return_list=True)
new_outs = [clone_replace(x) for x in new_outs] new_outs = [clone_replace(x) for x in new_outs]
old_outs = [] old_outs = []
if type(merging_node.outputs) not in (list, tuple): if type(merging_node.outputs) not in (list, tuple):
...@@ -684,7 +684,7 @@ def cond_remove_identical(fgraph, node): ...@@ -684,7 +684,7 @@ def cond_remove_identical(fgraph, node):
new_ifelse = IfElse(n_outs=len(nw_ts), as_view=op.as_view, gpu=op.gpu, name=op.name) new_ifelse = IfElse(n_outs=len(nw_ts), as_view=op.as_view, gpu=op.gpu, name=op.name)
new_ins = [node.inputs[0]] + nw_ts + nw_fs new_ins = [node.inputs[0]] + nw_ts + nw_fs
new_outs = new_ifelse(*new_ins, **dict(return_list=True)) new_outs = new_ifelse(*new_ins, return_list=True)
rval = [] rval = []
for idx in range(len(node.outputs)): for idx in range(len(node.outputs)):
...@@ -736,7 +736,7 @@ def cond_merge_random_op(fgraph, main_node): ...@@ -736,7 +736,7 @@ def cond_merge_random_op(fgraph, main_node):
gpu=False, gpu=False,
name=mn_name + "&" + pl_name, name=mn_name + "&" + pl_name,
) )
new_outs = new_ifelse(*new_ins, **dict(return_list=True)) new_outs = new_ifelse(*new_ins, return_list=True)
old_outs = [] old_outs = []
if type(merging_node.outputs) not in (list, tuple): if type(merging_node.outputs) not in (list, tuple):
old_outs += [merging_node.outputs] old_outs += [merging_node.outputs]
......
...@@ -218,7 +218,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node): ...@@ -218,7 +218,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
op_outs = clone_replace(op_outs, replace=givens) op_outs = clone_replace(op_outs, replace=givens)
nw_info = dataclasses.replace(op.info, n_seqs=nw_n_seqs) nw_info = dataclasses.replace(op.info, n_seqs=nw_n_seqs)
nwScan = Scan(nw_inner, op_outs, nw_info, op.mode) nwScan = Scan(nw_inner, op_outs, nw_info, op.mode)
nw_outs = nwScan(*nw_outer, **dict(return_list=True)) nw_outs = nwScan(*nw_outer, return_list=True)
return dict([("remove", [node])] + list(zip(node.outputs, nw_outs))) return dict([("remove", [node])] + list(zip(node.outputs, nw_outs)))
else: else:
return False return False
...@@ -399,9 +399,7 @@ class PushOutNonSeqScan(GlobalOptimizer): ...@@ -399,9 +399,7 @@ class PushOutNonSeqScan(GlobalOptimizer):
nwScan = Scan(op_ins, op_outs, op.info, op.mode) nwScan = Scan(op_ins, op_outs, op.info, op.mode)
# Do not call make_node for test_value # Do not call make_node for test_value
nw_node = nwScan(*(node.inputs + nw_outer), **dict(return_list=True))[ nw_node = nwScan(*(node.inputs + nw_outer), return_list=True)[0].owner
0
].owner
fgraph.replace_all_validate_remove( fgraph.replace_all_validate_remove(
list(zip(node.outputs, nw_node.outputs)), list(zip(node.outputs, nw_node.outputs)),
...@@ -672,7 +670,7 @@ class PushOutSeqScan(GlobalOptimizer): ...@@ -672,7 +670,7 @@ class PushOutSeqScan(GlobalOptimizer):
# Do not call make_node for test_value # Do not call make_node for test_value
nw_node = nwScan( nw_node = nwScan(
*(node.inputs[:1] + nw_outer + node.inputs[1:]), *(node.inputs[:1] + nw_outer + node.inputs[1:]),
**dict(return_list=True), return_list=True,
)[0].owner )[0].owner
fgraph.replace_all_validate_remove( fgraph.replace_all_validate_remove(
...@@ -958,9 +956,9 @@ class PushOutScanOutput(GlobalOptimizer): ...@@ -958,9 +956,9 @@ class PushOutScanOutput(GlobalOptimizer):
) )
# Create the Apply node for the scan op # Create the Apply node for the scan op
new_scan_node = new_scan_op( new_scan_node = new_scan_op(*new_scan_args.outer_inputs, return_list=True)[
*new_scan_args.outer_inputs, **dict(return_list=True) 0
)[0].owner ].owner
# Modify the outer graph to make sure the outputs of the new scan are # Modify the outer graph to make sure the outputs of the new scan are
# used instead of the outputs of the old scan # used instead of the outputs of the old scan
...@@ -1071,7 +1069,7 @@ class ScanInplaceOptimizer(GlobalOptimizer): ...@@ -1071,7 +1069,7 @@ class ScanInplaceOptimizer(GlobalOptimizer):
new_op.destroy_map = destroy_map new_op.destroy_map = destroy_map
# Do not call make_node for test_value # Do not call make_node for test_value
new_outs = new_op(*inputs, **dict(return_list=True)) new_outs = new_op(*inputs, return_list=True)
try: try:
fgraph.replace_all_validate_remove( fgraph.replace_all_validate_remove(
list(zip(node.outputs, new_outs)), list(zip(node.outputs, new_outs)),
...@@ -1595,9 +1593,7 @@ class ScanSaveMem(GlobalOptimizer): ...@@ -1595,9 +1593,7 @@ class ScanSaveMem(GlobalOptimizer):
return return
# Do not call make_node for test_value # Do not call make_node for test_value
new_outs = Scan(inps, outs, info, op.mode)( new_outs = Scan(inps, outs, info, op.mode)(*node_ins, return_list=True)
*node_ins, **dict(return_list=True)
)
old_new = [] old_new = []
# 3.7 Get replace pairs for those outputs that do not change # 3.7 Get replace pairs for those outputs that do not change
......
...@@ -771,7 +771,7 @@ class Rebroadcast(COp): ...@@ -771,7 +771,7 @@ class Rebroadcast(COp):
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
if eval_points[0] is None: if eval_points[0] is None:
return [None] return [None]
return self(*eval_points, **dict(return_list=True)) return self(*eval_points, return_list=True)
def c_code(self, node, nodename, inp, out, sub): def c_code(self, node, nodename, inp, out, sub):
(iname,) = inp (iname,) = inp
...@@ -1542,7 +1542,7 @@ class Alloc(COp): ...@@ -1542,7 +1542,7 @@ class Alloc(COp):
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
if eval_points[0] is None: if eval_points[0] is None:
return [None] return [None]
return self(eval_points[0], *inputs[1:], **dict(return_list=True)) return self(eval_points[0], *inputs[1:], return_list=True)
def do_constant_folding(self, fgraph, node): def do_constant_folding(self, fgraph, node):
clients = fgraph.clients[node.outputs[0]] clients = fgraph.clients[node.outputs[0]]
...@@ -1945,7 +1945,7 @@ class Split(COp): ...@@ -1945,7 +1945,7 @@ class Split(COp):
def grad(self, inputs, g_outputs): def grad(self, inputs, g_outputs):
"""Join the gradients along the axis that was used to split x.""" """Join the gradients along the axis that was used to split x."""
x, axis, n = inputs x, axis, n = inputs
outputs = self(*inputs, **dict(return_list=True)) outputs = self(*inputs, return_list=True)
# If all the output gradients are disconnected, then so are the inputs # If all the output gradients are disconnected, then so are the inputs
if builtins.all([isinstance(g.type, DisconnectedType) for g in g_outputs]): if builtins.all([isinstance(g.type, DisconnectedType) for g in g_outputs]):
return [ return [
......
...@@ -465,7 +465,7 @@ class InplaceElemwiseOptimizer(GlobalOptimizer): ...@@ -465,7 +465,7 @@ class InplaceElemwiseOptimizer(GlobalOptimizer):
) )
) )
new_outputs = self.op(new_scal, inplace_pattern)( new_outputs = self.op(new_scal, inplace_pattern)(
*node.inputs, **dict(return_list=True) *node.inputs, return_list=True
) )
new_node = new_outputs[0].owner new_node = new_outputs[0].owner
...@@ -684,7 +684,7 @@ def local_dimshuffle_lift(fgraph, node): ...@@ -684,7 +684,7 @@ def local_dimshuffle_lift(fgraph, node):
new_inp = op.__class__(inp.type.broadcastable, op.new_order)(inp) new_inp = op.__class__(inp.type.broadcastable, op.new_order)(inp)
new_inputs.append(apply_local_dimshuffle_lift(fgraph, new_inp)) new_inputs.append(apply_local_dimshuffle_lift(fgraph, new_inp))
copy_stack_trace(node.outputs[0], new_inputs) copy_stack_trace(node.outputs[0], new_inputs)
ret = inode.op(*new_inputs, **dict(return_list=True)) ret = inode.op(*new_inputs, return_list=True)
return ret return ret
if inode and isinstance(inode.op, DimShuffle): if inode and isinstance(inode.op, DimShuffle):
new_order = [x == "x" and "x" or inode.op.new_order[x] for x in new_order] new_order = [x == "x" and "x" or inode.op.new_order[x] for x in new_order]
......
...@@ -274,7 +274,7 @@ class DimShuffle(ExternalCOp): ...@@ -274,7 +274,7 @@ class DimShuffle(ExternalCOp):
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
if None in eval_points: if None in eval_points:
return [None] return [None]
return self(*eval_points, **dict(return_list=True)) return self(*eval_points, return_list=True)
def grad(self, inp, grads): def grad(self, inp, grads):
from aesara.tensor.basic import as_tensor_variable from aesara.tensor.basic import as_tensor_variable
...@@ -504,7 +504,7 @@ second dimension ...@@ -504,7 +504,7 @@ second dimension
return self.name return self.name
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
outs = self(*inputs, **dict(return_list=True)) outs = self(*inputs, return_list=True)
rval = [None for x in outs] rval = [None for x in outs]
# For each output # For each output
for idx, out in enumerate(outs): for idx, out in enumerate(outs):
......
...@@ -2462,7 +2462,7 @@ class Sum(CAReduceDtype): ...@@ -2462,7 +2462,7 @@ class Sum(CAReduceDtype):
# part of self # part of self
if None in eval_points: if None in eval_points:
return [None] return [None]
return self(*eval_points, **dict(return_list=True)) return self(*eval_points, return_list=True)
def sum(input, axis=None, dtype=None, keepdims=False, acc_dtype=None): def sum(input, axis=None, dtype=None, keepdims=False, acc_dtype=None):
......
...@@ -577,7 +577,7 @@ class Reshape(COp): ...@@ -577,7 +577,7 @@ class Reshape(COp):
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
if eval_points[0] is None: if eval_points[0] is None:
return [None] return [None]
return self(eval_points[0], *inputs[1:], **dict(return_list=True)) return self(eval_points[0], *inputs[1:], return_list=True)
def infer_shape(self, fgraph, node, ishapes): def infer_shape(self, fgraph, node, ishapes):
from aesara.tensor.math import eq, maximum, mul from aesara.tensor.math import eq, maximum, mul
......
...@@ -1149,7 +1149,7 @@ class Subtensor(COp): ...@@ -1149,7 +1149,7 @@ class Subtensor(COp):
# (they should be defaulted to zeros_like by the global R_op) # (they should be defaulted to zeros_like by the global R_op)
if eval_points[0] is None: if eval_points[0] is None:
return [None] return [None]
return self(eval_points[0], *inputs[1:], **dict(return_list=True)) return self(eval_points[0], *inputs[1:], return_list=True)
class SubtensorPrinter: class SubtensorPrinter:
...@@ -1764,9 +1764,7 @@ class IncSubtensor(COp): ...@@ -1764,9 +1764,7 @@ class IncSubtensor(COp):
return [None] return [None]
# Again we ignore eval points for indices because incsubtensor is # Again we ignore eval points for indices because incsubtensor is
# not differentiable wrt to those # not differentiable wrt to those
return self( return self(eval_points[0], eval_points[1], *inputs[2:], return_list=True)
eval_points[0], eval_points[1], *inputs[2:], **dict(return_list=True)
)
def connection_pattern(self, node): def connection_pattern(self, node):
......
...@@ -41,7 +41,7 @@ class TestGpuBroadcast(test_elemwise.TestBroadcast): ...@@ -41,7 +41,7 @@ class TestGpuBroadcast(test_elemwise.TestBroadcast):
linkers = [PerformLinker, CLinker] linkers = [PerformLinker, CLinker]
def rand_cval(self, shp): def rand_cval(self, shp):
return rand_gpuarray(*shp, **dict(cls=gpuarray)) return rand_gpuarray(*shp, cls=gpuarray)
def test_elemwise_pow(): def test_elemwise_pow():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论