提交 f7a5312c authored 作者: Mohammad Pezeshki's avatar Mohammad Pezeshki

ifelse.py in pep8

上级 b385a7e9
......@@ -142,8 +142,8 @@ class IfElse(PureOp):
gpu=False,
name='_'.join(name_tokens))
new_outs = new_ifelse(node.inputs[0],
*(new_ts_inputs + new_fs_inputs),
**dict(return_list=True))
*(new_ts_inputs + new_fs_inputs),
**dict(return_list=True))
else:
new_outs = []
......@@ -160,8 +160,8 @@ class IfElse(PureOp):
def make_node(self, c, *args):
assert len(args) == 2 * self.n_outs, (
"Wrong number of arguments to make_node: "
"expected %d, got %d" % (2 * self.n_outs, len(args))
"Wrong number of arguments to make_node: "
"expected %d, got %d" % (2 * self.n_outs, len(args))
)
if not self.gpu:
# When gpu is true, we are given only cuda ndarrays, and we want
......@@ -328,10 +328,10 @@ def ifelse(condition, then_branch, else_branch, name=None):
for then_branch_elem, else_branch_elem in izip(then_branch, else_branch):
if not isinstance(then_branch_elem, theano.Variable):
then_branch_elem = theano.tensor.as_tensor_variable(
then_branch_elem)
then_branch_elem)
if not isinstance(else_branch_elem, theano.Variable):
else_branch_elem = theano.tensor.as_tensor_variable(
else_branch_elem)
else_branch_elem)
if then_branch_elem.type != else_branch_elem.type:
# If one of them is a TensorType, and the other one can be
......@@ -341,22 +341,22 @@ def ifelse(condition, then_branch, else_branch, name=None):
if (isinstance(then_branch_elem.type, TensorType)
and not isinstance(else_branch_elem.type, TensorType)):
else_branch_elem = then_branch_elem.type.filter_variable(
else_branch_elem)
else_branch_elem)
elif (isinstance(else_branch_elem.type, TensorType)
and not isinstance(then_branch_elem.type, TensorType)):
then_branch_elem = else_branch_elem.type.filter_variable(
then_branch_elem)
then_branch_elem)
if then_branch_elem.type != else_branch_elem.type:
# If the types still don't match, there is a problem.
raise TypeError(
'The two branches should have identical types, but '
'they are %s and %s respectively. This error could be '
'raised if for example you provided a one element '
'list on the `then` branch but a tensor on the `else` '
'branch.' %
(then_branch_elem.type, else_branch_elem.type))
'The two branches should have identical types, but '
'they are %s and %s respectively. This error could be '
'raised if for example you provided a one element '
'list on the `then` branch but a tensor on the `else` '
'branch.' %
(then_branch_elem.type, else_branch_elem.type))
new_then_branch.append(then_branch_elem)
new_else_branch.append(else_branch_elem)
......@@ -396,7 +396,7 @@ def cond_make_inplace(node):
optdb.register('cond_make_inplace', opt.in2out(cond_make_inplace,
ignore_newtrees=True), 95, 'fast_run', 'inplace')
ignore_newtrees=True), 95, 'fast_run', 'inplace')
# XXX: Optimizations commented pending further debugging (certain optimizations
# make computation less lazy than it should be currently).
......@@ -460,7 +460,7 @@ def ifelse_lift_single_if_through_acceptable_ops(main_node):
for inp in main_node.inputs:
all_inp_nodes.add(inp.owner)
ifnodes = [x for x in list(all_inp_nodes)
if x and isinstance(x.op, IfElse)]
if x and isinstance(x.op, IfElse)]
# if we have multiple ifs as inputs .. it all becomes quite complicated
# :)
if len(ifnodes) != 1:
......@@ -471,7 +471,7 @@ def ifelse_lift_single_if_through_acceptable_ops(main_node):
ts = node.inputs[1:][:op.n_outs]
fs = node.inputs[1:][op.n_outs:]
outs = main_node.outputs
# outs = main_node.outputs
mop = main_node.op
true_ins = []
false_ins = []
......@@ -486,8 +486,8 @@ def ifelse_lift_single_if_through_acceptable_ops(main_node):
false_ins.append(x)
true_eval = mop(*true_ins, **dict(return_list=True))
false_eval = mop(*false_ins, **dict(return_list=True))
#true_eval = clone(outs, replace = dict(zip(node.outputs, ts)))
#false_eval = clone(outs, replace = dict(zip(node.outputs, fs)))
# true_eval = clone(outs, replace = dict(zip(node.outputs, ts)))
# false_eval = clone(outs, replace = dict(zip(node.outputs, fs)))
nw_outs = ifelse(node.inputs[0], true_eval, false_eval, return_list=True)
return nw_outs
......@@ -503,10 +503,10 @@ def cond_merge_ifs_true(node):
replace = {}
for idx, tval in enumerate(t_ins):
if (tval.owner and isinstance(tval.owner.op, IfElse) and
tval.owner.inputs[0] == node.inputs[0]):
ins_op = tval.owner.op
ins_t = tval.owner.inputs[1:][:ins_op.n_outs]
replace[idx + 1] = ins_t[tval.owner.outputs.index(tval)]
tval.owner.inputs[0] == node.inputs[0]):
ins_op = tval.owner.op
ins_t = tval.owner.inputs[1:][:ins_op.n_outs]
replace[idx + 1] = ins_t[tval.owner.outputs.index(tval)]
if len(replace.items()) == 0:
return False
......@@ -527,10 +527,10 @@ def cond_merge_ifs_false(node):
replace = {}
for idx, fval in enumerate(f_ins):
if (fval.owner and isinstance(fval.owner.op, IfElse) and
fval.owner.inputs[0] == node.inputs[0]):
ins_op = fval.owner.op
ins_t = fval.owner.inputs[1:][ins_op.n_outs:]
replace[idx + 1 + op.n_outs] = \
fval.owner.inputs[0] == node.inputs[0]):
ins_op = fval.owner.op
ins_t = fval.owner.inputs[1:][ins_op.n_outs:]
replace[idx + 1 + op.n_outs] = \
ins_t[fval.owner.outputs.index(fval)]
if len(replace.items()) == 0:
......@@ -555,7 +555,7 @@ class CondMerge(gof.Optimizer):
merging_node = cond_nodes[0]
for proposal in cond_nodes[1:]:
if (proposal.inputs[0] == merging_node.inputs[0] and
not find_up(proposal, merging_node)):
not find_up(proposal, merging_node)):
# Create a list of replacements for proposal
mn_ts = merging_node.inputs[1:][:merging_node.op.n_outs]
mn_fs = merging_node.inputs[1:][merging_node.op.n_outs:]
......@@ -567,8 +567,8 @@ class CondMerge(gof.Optimizer):
if merging_node.op.name:
mn_name = merging_node.op.name
pl_name = '?'
mn_n_ts = len(mn_ts)
mn_n_fs = len(mn_fs)
# mn_n_ts = len(mn_ts)
# mn_n_fs = len(mn_fs)
if proposal.op.name:
pl_name = proposal.op.name
new_ifelse = IfElse(
......@@ -607,8 +607,8 @@ def cond_remove_identical(node):
if idx not in out_map:
for jdx in xrange(idx + 1, len(node.outputs)):
if (ts[idx] == ts[jdx] and
fs[idx] == fs[jdx] and
jdx not in out_map):
fs[idx] == fs[jdx] and
jdx not in out_map):
out_map[jdx] = idx
if len(out_map.keys()) == 0:
......@@ -652,7 +652,7 @@ def cond_merge_random_op(main_node):
for inp in main_node.inputs:
all_inp_nodes.add(inp.owner)
cond_nodes = [x for x in list(all_inp_nodes)
if x and isinstance(x.op, IfElse)]
if x and isinstance(x.op, IfElse)]
if len(cond_nodes) < 2:
return False
......@@ -660,8 +660,8 @@ def cond_merge_random_op(main_node):
merging_node = cond_nodes[0]
for proposal in cond_nodes[1:]:
if (proposal.inputs[0] == merging_node.inputs[0] and
not find_up(proposal, merging_node) and
not find_up(merging_node, proposal)):
not find_up(proposal, merging_node) and
not find_up(merging_node, proposal)):
# Create a list of replacements for proposal
mn_ts = merging_node.inputs[1:][:merging_node.op.n_outs]
mn_fs = merging_node.inputs[1:][merging_node.op.n_outs:]
......@@ -673,8 +673,8 @@ def cond_merge_random_op(main_node):
if merging_node.op.name:
mn_name = merging_node.op.name
pl_name = '?'
mn_n_ts = len(mn_ts)
mn_n_fs = len(mn_fs)
# mn_n_ts = len(mn_ts)
# mn_n_fs = len(mn_fs)
if proposal.op.name:
pl_name = proposal.op.name
new_ifelse = IfElse(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论