提交 f7a5312c authored 作者: Mohammad Pezeshki's avatar Mohammad Pezeshki

ifelse.py in pep8

上级 b385a7e9
...@@ -142,8 +142,8 @@ class IfElse(PureOp): ...@@ -142,8 +142,8 @@ class IfElse(PureOp):
gpu=False, gpu=False,
name='_'.join(name_tokens)) name='_'.join(name_tokens))
new_outs = new_ifelse(node.inputs[0], new_outs = new_ifelse(node.inputs[0],
*(new_ts_inputs + new_fs_inputs), *(new_ts_inputs + new_fs_inputs),
**dict(return_list=True)) **dict(return_list=True))
else: else:
new_outs = [] new_outs = []
...@@ -160,8 +160,8 @@ class IfElse(PureOp): ...@@ -160,8 +160,8 @@ class IfElse(PureOp):
def make_node(self, c, *args): def make_node(self, c, *args):
assert len(args) == 2 * self.n_outs, ( assert len(args) == 2 * self.n_outs, (
"Wrong number of arguments to make_node: " "Wrong number of arguments to make_node: "
"expected %d, got %d" % (2 * self.n_outs, len(args)) "expected %d, got %d" % (2 * self.n_outs, len(args))
) )
if not self.gpu: if not self.gpu:
# When gpu is true, we are given only cuda ndarrays, and we want # When gpu is true, we are given only cuda ndarrays, and we want
...@@ -328,10 +328,10 @@ def ifelse(condition, then_branch, else_branch, name=None): ...@@ -328,10 +328,10 @@ def ifelse(condition, then_branch, else_branch, name=None):
for then_branch_elem, else_branch_elem in izip(then_branch, else_branch): for then_branch_elem, else_branch_elem in izip(then_branch, else_branch):
if not isinstance(then_branch_elem, theano.Variable): if not isinstance(then_branch_elem, theano.Variable):
then_branch_elem = theano.tensor.as_tensor_variable( then_branch_elem = theano.tensor.as_tensor_variable(
then_branch_elem) then_branch_elem)
if not isinstance(else_branch_elem, theano.Variable): if not isinstance(else_branch_elem, theano.Variable):
else_branch_elem = theano.tensor.as_tensor_variable( else_branch_elem = theano.tensor.as_tensor_variable(
else_branch_elem) else_branch_elem)
if then_branch_elem.type != else_branch_elem.type: if then_branch_elem.type != else_branch_elem.type:
# If one of them is a TensorType, and the other one can be # If one of them is a TensorType, and the other one can be
...@@ -341,22 +341,22 @@ def ifelse(condition, then_branch, else_branch, name=None): ...@@ -341,22 +341,22 @@ def ifelse(condition, then_branch, else_branch, name=None):
if (isinstance(then_branch_elem.type, TensorType) if (isinstance(then_branch_elem.type, TensorType)
and not isinstance(else_branch_elem.type, TensorType)): and not isinstance(else_branch_elem.type, TensorType)):
else_branch_elem = then_branch_elem.type.filter_variable( else_branch_elem = then_branch_elem.type.filter_variable(
else_branch_elem) else_branch_elem)
elif (isinstance(else_branch_elem.type, TensorType) elif (isinstance(else_branch_elem.type, TensorType)
and not isinstance(then_branch_elem.type, TensorType)): and not isinstance(then_branch_elem.type, TensorType)):
then_branch_elem = else_branch_elem.type.filter_variable( then_branch_elem = else_branch_elem.type.filter_variable(
then_branch_elem) then_branch_elem)
if then_branch_elem.type != else_branch_elem.type: if then_branch_elem.type != else_branch_elem.type:
# If the types still don't match, there is a problem. # If the types still don't match, there is a problem.
raise TypeError( raise TypeError(
'The two branches should have identical types, but ' 'The two branches should have identical types, but '
'they are %s and %s respectively. This error could be ' 'they are %s and %s respectively. This error could be '
'raised if for example you provided a one element ' 'raised if for example you provided a one element '
'list on the `then` branch but a tensor on the `else` ' 'list on the `then` branch but a tensor on the `else` '
'branch.' % 'branch.' %
(then_branch_elem.type, else_branch_elem.type)) (then_branch_elem.type, else_branch_elem.type))
new_then_branch.append(then_branch_elem) new_then_branch.append(then_branch_elem)
new_else_branch.append(else_branch_elem) new_else_branch.append(else_branch_elem)
...@@ -396,7 +396,7 @@ def cond_make_inplace(node): ...@@ -396,7 +396,7 @@ def cond_make_inplace(node):
optdb.register('cond_make_inplace', opt.in2out(cond_make_inplace, optdb.register('cond_make_inplace', opt.in2out(cond_make_inplace,
ignore_newtrees=True), 95, 'fast_run', 'inplace') ignore_newtrees=True), 95, 'fast_run', 'inplace')
# XXX: Optimizations commented pending further debugging (certain optimizations # XXX: Optimizations commented pending further debugging (certain optimizations
# make computation less lazy than it should be currently). # make computation less lazy than it should be currently).
...@@ -460,7 +460,7 @@ def ifelse_lift_single_if_through_acceptable_ops(main_node): ...@@ -460,7 +460,7 @@ def ifelse_lift_single_if_through_acceptable_ops(main_node):
for inp in main_node.inputs: for inp in main_node.inputs:
all_inp_nodes.add(inp.owner) all_inp_nodes.add(inp.owner)
ifnodes = [x for x in list(all_inp_nodes) ifnodes = [x for x in list(all_inp_nodes)
if x and isinstance(x.op, IfElse)] if x and isinstance(x.op, IfElse)]
# if we have multiple ifs as inputs .. it all becomes quite complicated # if we have multiple ifs as inputs .. it all becomes quite complicated
# :) # :)
if len(ifnodes) != 1: if len(ifnodes) != 1:
...@@ -471,7 +471,7 @@ def ifelse_lift_single_if_through_acceptable_ops(main_node): ...@@ -471,7 +471,7 @@ def ifelse_lift_single_if_through_acceptable_ops(main_node):
ts = node.inputs[1:][:op.n_outs] ts = node.inputs[1:][:op.n_outs]
fs = node.inputs[1:][op.n_outs:] fs = node.inputs[1:][op.n_outs:]
outs = main_node.outputs # outs = main_node.outputs
mop = main_node.op mop = main_node.op
true_ins = [] true_ins = []
false_ins = [] false_ins = []
...@@ -486,8 +486,8 @@ def ifelse_lift_single_if_through_acceptable_ops(main_node): ...@@ -486,8 +486,8 @@ def ifelse_lift_single_if_through_acceptable_ops(main_node):
false_ins.append(x) false_ins.append(x)
true_eval = mop(*true_ins, **dict(return_list=True)) true_eval = mop(*true_ins, **dict(return_list=True))
false_eval = mop(*false_ins, **dict(return_list=True)) false_eval = mop(*false_ins, **dict(return_list=True))
#true_eval = clone(outs, replace = dict(zip(node.outputs, ts))) # true_eval = clone(outs, replace = dict(zip(node.outputs, ts)))
#false_eval = clone(outs, replace = dict(zip(node.outputs, fs))) # false_eval = clone(outs, replace = dict(zip(node.outputs, fs)))
nw_outs = ifelse(node.inputs[0], true_eval, false_eval, return_list=True) nw_outs = ifelse(node.inputs[0], true_eval, false_eval, return_list=True)
return nw_outs return nw_outs
...@@ -503,10 +503,10 @@ def cond_merge_ifs_true(node): ...@@ -503,10 +503,10 @@ def cond_merge_ifs_true(node):
replace = {} replace = {}
for idx, tval in enumerate(t_ins): for idx, tval in enumerate(t_ins):
if (tval.owner and isinstance(tval.owner.op, IfElse) and if (tval.owner and isinstance(tval.owner.op, IfElse) and
tval.owner.inputs[0] == node.inputs[0]): tval.owner.inputs[0] == node.inputs[0]):
ins_op = tval.owner.op ins_op = tval.owner.op
ins_t = tval.owner.inputs[1:][:ins_op.n_outs] ins_t = tval.owner.inputs[1:][:ins_op.n_outs]
replace[idx + 1] = ins_t[tval.owner.outputs.index(tval)] replace[idx + 1] = ins_t[tval.owner.outputs.index(tval)]
if len(replace.items()) == 0: if len(replace.items()) == 0:
return False return False
...@@ -527,10 +527,10 @@ def cond_merge_ifs_false(node): ...@@ -527,10 +527,10 @@ def cond_merge_ifs_false(node):
replace = {} replace = {}
for idx, fval in enumerate(f_ins): for idx, fval in enumerate(f_ins):
if (fval.owner and isinstance(fval.owner.op, IfElse) and if (fval.owner and isinstance(fval.owner.op, IfElse) and
fval.owner.inputs[0] == node.inputs[0]): fval.owner.inputs[0] == node.inputs[0]):
ins_op = fval.owner.op ins_op = fval.owner.op
ins_t = fval.owner.inputs[1:][ins_op.n_outs:] ins_t = fval.owner.inputs[1:][ins_op.n_outs:]
replace[idx + 1 + op.n_outs] = \ replace[idx + 1 + op.n_outs] = \
ins_t[fval.owner.outputs.index(fval)] ins_t[fval.owner.outputs.index(fval)]
if len(replace.items()) == 0: if len(replace.items()) == 0:
...@@ -555,7 +555,7 @@ class CondMerge(gof.Optimizer): ...@@ -555,7 +555,7 @@ class CondMerge(gof.Optimizer):
merging_node = cond_nodes[0] merging_node = cond_nodes[0]
for proposal in cond_nodes[1:]: for proposal in cond_nodes[1:]:
if (proposal.inputs[0] == merging_node.inputs[0] and if (proposal.inputs[0] == merging_node.inputs[0] and
not find_up(proposal, merging_node)): not find_up(proposal, merging_node)):
# Create a list of replacements for proposal # Create a list of replacements for proposal
mn_ts = merging_node.inputs[1:][:merging_node.op.n_outs] mn_ts = merging_node.inputs[1:][:merging_node.op.n_outs]
mn_fs = merging_node.inputs[1:][merging_node.op.n_outs:] mn_fs = merging_node.inputs[1:][merging_node.op.n_outs:]
...@@ -567,8 +567,8 @@ class CondMerge(gof.Optimizer): ...@@ -567,8 +567,8 @@ class CondMerge(gof.Optimizer):
if merging_node.op.name: if merging_node.op.name:
mn_name = merging_node.op.name mn_name = merging_node.op.name
pl_name = '?' pl_name = '?'
mn_n_ts = len(mn_ts) # mn_n_ts = len(mn_ts)
mn_n_fs = len(mn_fs) # mn_n_fs = len(mn_fs)
if proposal.op.name: if proposal.op.name:
pl_name = proposal.op.name pl_name = proposal.op.name
new_ifelse = IfElse( new_ifelse = IfElse(
...@@ -607,8 +607,8 @@ def cond_remove_identical(node): ...@@ -607,8 +607,8 @@ def cond_remove_identical(node):
if idx not in out_map: if idx not in out_map:
for jdx in xrange(idx + 1, len(node.outputs)): for jdx in xrange(idx + 1, len(node.outputs)):
if (ts[idx] == ts[jdx] and if (ts[idx] == ts[jdx] and
fs[idx] == fs[jdx] and fs[idx] == fs[jdx] and
jdx not in out_map): jdx not in out_map):
out_map[jdx] = idx out_map[jdx] = idx
if len(out_map.keys()) == 0: if len(out_map.keys()) == 0:
...@@ -652,7 +652,7 @@ def cond_merge_random_op(main_node): ...@@ -652,7 +652,7 @@ def cond_merge_random_op(main_node):
for inp in main_node.inputs: for inp in main_node.inputs:
all_inp_nodes.add(inp.owner) all_inp_nodes.add(inp.owner)
cond_nodes = [x for x in list(all_inp_nodes) cond_nodes = [x for x in list(all_inp_nodes)
if x and isinstance(x.op, IfElse)] if x and isinstance(x.op, IfElse)]
if len(cond_nodes) < 2: if len(cond_nodes) < 2:
return False return False
...@@ -660,8 +660,8 @@ def cond_merge_random_op(main_node): ...@@ -660,8 +660,8 @@ def cond_merge_random_op(main_node):
merging_node = cond_nodes[0] merging_node = cond_nodes[0]
for proposal in cond_nodes[1:]: for proposal in cond_nodes[1:]:
if (proposal.inputs[0] == merging_node.inputs[0] and if (proposal.inputs[0] == merging_node.inputs[0] and
not find_up(proposal, merging_node) and not find_up(proposal, merging_node) and
not find_up(merging_node, proposal)): not find_up(merging_node, proposal)):
# Create a list of replacements for proposal # Create a list of replacements for proposal
mn_ts = merging_node.inputs[1:][:merging_node.op.n_outs] mn_ts = merging_node.inputs[1:][:merging_node.op.n_outs]
mn_fs = merging_node.inputs[1:][merging_node.op.n_outs:] mn_fs = merging_node.inputs[1:][merging_node.op.n_outs:]
...@@ -673,8 +673,8 @@ def cond_merge_random_op(main_node): ...@@ -673,8 +673,8 @@ def cond_merge_random_op(main_node):
if merging_node.op.name: if merging_node.op.name:
mn_name = merging_node.op.name mn_name = merging_node.op.name
pl_name = '?' pl_name = '?'
mn_n_ts = len(mn_ts) # mn_n_ts = len(mn_ts)
mn_n_fs = len(mn_fs) # mn_n_fs = len(mn_fs)
if proposal.op.name: if proposal.op.name:
pl_name = proposal.op.name pl_name = proposal.op.name
new_ifelse = IfElse( new_ifelse = IfElse(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论