提交 ce690a69 authored 作者: Razvan Pascanu's avatar Razvan Pascanu 提交者: David Warde-Farley

Made the code pep8 compatible

上级 8c08c76e
""" """
IfElse is an Op that works with the LazyLinker to support conditional graph evaluation. IfElse is an Op that works with the LazyLinker to support conditional graph
evaluation.
:TODO: Add text to library documentation describing the IfElse Op. :TODO: Add text to library documentation describing the IfElse Op.
""" """
__docformat__ = 'restructedtext en' __docformat__ = 'restructedtext en'
__authors__ = ( "Razvan Pascanu " __authors__ = ("Razvan Pascanu "
"James Bergstra " "James Bergstra "
"Dumitru Erhan ") "Dumitru Erhan ")
__copyright__ = "(c) 2010, Universite de Montreal" __copyright__ = "(c) 2010, Universite de Montreal"
...@@ -27,8 +28,6 @@ from scan_module.scan_utils import clone ...@@ -27,8 +28,6 @@ from scan_module.scan_utils import clone
_logger = logging.getLogger('theano.ifelse') _logger = logging.getLogger('theano.ifelse')
class IfElse(PureOp): class IfElse(PureOp):
""" """
Op that works with CVM/VM to support conditional graph evaluation. Note, Op that works with CVM/VM to support conditional graph evaluation. Note,
...@@ -47,13 +46,13 @@ class IfElse(PureOp): ...@@ -47,13 +46,13 @@ class IfElse(PureOp):
False branch before picking one. False branch before picking one.
""" """
def __init__(self, n_outs, as_view=False, gpu = False, name = None): def __init__(self, n_outs, as_view=False, gpu=False, name=None):
if as_view: if as_view:
# check destroyhandler and others to ensure that a view_map with # check destroyhandler and others to ensure that a view_map with
# multiple inputs can work # multiple inputs can work
view_map = {} view_map = {}
for idx in xrange(n_outs): for idx in xrange(n_outs):
view_map[idx] = [idx+1] view_map[idx] = [idx + 1]
self.view_map = view_map self.view_map = view_map
#raise NotImplementedError('Cond must copy for now') #raise NotImplementedError('Cond must copy for now')
self.as_view = as_view self.as_view = as_view
...@@ -61,7 +60,6 @@ class IfElse(PureOp): ...@@ -61,7 +60,6 @@ class IfElse(PureOp):
self.n_outs = n_outs self.n_outs = n_outs
self.name = name self.name = name
def __eq__(self, other): def __eq__(self, other):
if not type(self) == type(other): if not type(self) == type(other):
return False return False
...@@ -75,7 +73,7 @@ class IfElse(PureOp): ...@@ -75,7 +73,7 @@ class IfElse(PureOp):
def __hash__(self): def __hash__(self):
rval = ( hash(type(self)) ^ rval = (hash(type(self)) ^
hash(self.as_view) ^ hash(self.as_view) ^
hash(self.gpu) ^ hash(self.gpu) ^
hash(self.n_outs)) hash(self.n_outs))
...@@ -83,7 +81,7 @@ class IfElse(PureOp): ...@@ -83,7 +81,7 @@ class IfElse(PureOp):
return rval return rval
def __str__(self): def __str__(self):
name ='if{%s'%str(self.name) name = 'if{%s' % str(self.name)
if self.as_view: if self.as_view:
name += ',inplace' name += ',inplace'
if self.gpu: if self.gpu:
...@@ -91,7 +89,6 @@ class IfElse(PureOp): ...@@ -91,7 +89,6 @@ class IfElse(PureOp):
name += '}' name += '}'
return name return name
def infer_shape(self, node, inputs_shapes): def infer_shape(self, node, inputs_shapes):
# By construction, corresponding then/else pairs have the same number # By construction, corresponding then/else pairs have the same number
# of dimensions # of dimensions
...@@ -103,7 +100,7 @@ class IfElse(PureOp): ...@@ -103,7 +100,7 @@ class IfElse(PureOp):
new_ts_inputs += list(ts_shape) new_ts_inputs += list(ts_shape)
else: else:
# It can be None for generic objects # It can be None for generic objects
return [None]*self.n_outs return [None] * self.n_outs
new_fs_inputs = [] new_fs_inputs = []
for fs_shape in fs_shapes: for fs_shape in fs_shapes:
...@@ -111,18 +108,18 @@ class IfElse(PureOp): ...@@ -111,18 +108,18 @@ class IfElse(PureOp):
new_fs_inputs += list(fs_shape) new_fs_inputs += list(fs_shape)
else: else:
# It can be None for generic objects # It can be None for generic objects
return [None]*self.n_outs return [None] * self.n_outs
assert len(new_ts_inputs) == len(new_fs_inputs) assert len(new_ts_inputs) == len(new_fs_inputs)
if len(new_ts_inputs + new_fs_inputs) > 0: if len(new_ts_inputs + new_fs_inputs) > 0:
new_ifelse = IfElse( new_ifelse = IfElse(
n_outs = len(new_ts_inputs), n_outs=len(new_ts_inputs),
as_view = False, as_view=False,
gpu = False, gpu=False,
name='shape_'+str(self.name)) name='shape_' + str(self.name))
new_outs = new_ifelse.make_node(node.inputs[0], new_outs = new_ifelse.make_node(node.inputs[0],
*(new_ts_inputs+new_fs_inputs)).outputs *(new_ts_inputs + new_fs_inputs)).outputs
else: else:
new_outs = [] new_outs = []
...@@ -138,12 +135,8 @@ class IfElse(PureOp): ...@@ -138,12 +135,8 @@ class IfElse(PureOp):
out_shapes += [tuple(current_shape)] out_shapes += [tuple(current_shape)]
return out_shapes return out_shapes
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
return self.make_node(inputs[0],*eval_points[1:]).outputs return self.make_node(inputs[0], *eval_points[1:]).outputs
def make_node(self, c, *args): def make_node(self, c, *args):
if not self.gpu: if not self.gpu:
...@@ -160,15 +153,15 @@ class IfElse(PureOp): ...@@ -160,15 +153,15 @@ class IfElse(PureOp):
ts = args[:self.n_outs] ts = args[:self.n_outs]
fs = args[self.n_outs:] fs = args[self.n_outs:]
for t,f in zip(ts, fs): for t, f in zip(ts, fs):
if t.type != f.type: if t.type != f.type:
raise TypeError(('IfElse requires same types for true and ' raise TypeError(('IfElse requires same types for true and '
'false return values'), t, f, t.type, f.type) 'false return values'), t, f, t.type, f.type)
if c.ndim >0: if c.ndim > 0:
raise TypeError(('Condition given to the op has to be a scalar ' raise TypeError(('Condition given to the op has to be a scalar '
'with 0 standing for False, anything else for True')) 'with 0 standing for False, anything else '
return Apply(self, [c]+list(args), [t.type() for t in ts]) 'for True'))
return Apply(self, [c] + list(args), [t.type() for t in ts])
def grad(self, ins, grads): def grad(self, ins, grads):
ts = ins[1:][:self.n_outs] ts = ins[1:][:self.n_outs]
...@@ -179,43 +172,39 @@ class IfElse(PureOp): ...@@ -179,43 +172,39 @@ class IfElse(PureOp):
else: else:
nw_name_t = None nw_name_t = None
nw_name_f = None nw_name_f = None
if_true_op = IfElse(n_outs = self.n_outs, if_true_op = IfElse(n_outs=self.n_outs,
as_view = self.as_view, as_view=self.as_view,
gpu = self.gpu, gpu=self.gpu,
name = nw_name_t) name=nw_name_t)
if_false_op = IfElse(n_outs = self.n_outs, if_false_op = IfElse(n_outs=self.n_outs,
as_view = self.as_view, as_view=self.as_view,
gpu = self.gpu, gpu=self.gpu,
name = nw_name_f) name=nw_name_f)
if_true = ([ins[0]]+ grads+ [theano.tensor.zeros_like(t) if_true = ([ins[0]] + grads + [theano.tensor.zeros_like(t)
for t in ts]) for t in ts])
if_false = ([ins[0]] + [theano.tensor.zeros_like(f) if_false = ([ins[0]] + [theano.tensor.zeros_like(f)
for f in fs] + grads) for f in fs] + grads)
return ([None]+ return ([None] +
if_true_op.make_node(*if_true).outputs + if_true_op.make_node(*if_true).outputs +
if_false_op.make_node(*if_false).outputs ) if_false_op.make_node(*if_false).outputs)
def make_thunk(self, node, storage_map, compute_map, no_recycling): def make_thunk(self, node, storage_map, compute_map, no_recycling):
outtypes = [ out.type for out in node.outputs] outtypes = [out.type for out in node.outputs]
cond = node.inputs[0] cond = node.inputs[0]
ts = node.inputs[1:][:self.n_outs] ts = node.inputs[1:][:self.n_outs]
fs = node.inputs[1:][self.n_outs:] fs = node.inputs[1:][self.n_outs:]
outputs = node.outputs outputs = node.outputs
def thunk(): def thunk():
if not compute_map[cond][0]: if not compute_map[cond][0]:
return [0] return [0]
else: else:
truthval = storage_map[cond][0] truthval = storage_map[cond][0]
if truthval != 0: if truthval != 0:
ls = [idx+1 for idx in xrange(self.n_outs) ls = [idx + 1 for idx in xrange(self.n_outs)
if not compute_map[ts[idx]][0]] if not compute_map[ts[idx]][0]]
if len(ls) > 0: if len(ls) > 0:
return ls return ls
...@@ -232,7 +221,7 @@ class IfElse(PureOp): ...@@ -232,7 +221,7 @@ class IfElse(PureOp):
storage_map[out][0] = oval storage_map[out][0] = oval
return [] return []
else: else:
ls = [1+idx+self.n_outs for idx in xrange(self.n_outs) ls = [1 + idx + self.n_outs for idx in xrange(self.n_outs)
if not compute_map[fs[idx]][0]] if not compute_map[fs[idx]][0]]
if len(ls) > 0: if len(ls) > 0:
return ls return ls
...@@ -249,12 +238,12 @@ class IfElse(PureOp): ...@@ -249,12 +238,12 @@ class IfElse(PureOp):
return [] return []
thunk.lazy = True thunk.lazy = True
thunk.inputs = [ storage_map[v] for v in node.inputs] thunk.inputs = [storage_map[v] for v in node.inputs]
thunk.outputs = [ storage_map[v] for v in node.outputs] thunk.outputs = [storage_map[v] for v in node.outputs]
return thunk return thunk
def ifelse( cond, true_branch, false_branch, name = None): def ifelse(cond, true_branch, false_branch, name=None):
""" """
This function corresponds to a if statement, returning inputs in the This function corresponds to a if statement, returning inputs in the
``true_branch`` if ``cond`` evaluates to True or inputs in the ``true_branch`` if ``cond`` evaluates to True or inputs in the
...@@ -293,34 +282,33 @@ def ifelse( cond, true_branch, false_branch, name = None): ...@@ -293,34 +282,33 @@ def ifelse( cond, true_branch, false_branch, name = None):
false_branch = [false_branch] false_branch = [false_branch]
if len(true_branch) != len(false_branch): if len(true_branch) != len(false_branch):
raise ValueError(( 'The number of values on the `then` branch'+ raise ValueError(('The number of values on the `then` branch'
' should have the same number of variables as '+ ' should have the same number of variables as '
'the `else` branch : (variables on `then` '+ 'the `else` branch : (variables on `then` '
'%d'%len(true_branch)+ ', variables on `else` '+ '%d' % len(true_branch) + ', variables on `else` '
'%d'%len(false_branch)+')')) '%d' % len(false_branch) + ')'))
new_ifelse = IfElse(n_outs = len(true_branch), new_ifelse = IfElse(n_outs=len(true_branch),
as_view=False, as_view=False,
gpu = False, gpu=False,
name = name) name=name)
ins = [cond] + list(true_branch) + list(false_branch) ins = [cond] + list(true_branch) + list(false_branch)
rval = new_ifelse.make_node(*ins).outputs rval = new_ifelse.make_node(*ins).outputs
if type(rval) in (list,tuple) and len(rval) == 1: if type(rval) in (list, tuple) and len(rval) == 1:
return rval[0] return rval[0]
else: else:
return rval return rval
@gof.local_optimizer([None]) @gof.local_optimizer([None])
def cond_make_inplace(node): def cond_make_inplace(node):
op = node.op op = node.op
if isinstance(op, IfElse) and not op.as_view : if isinstance(op, IfElse) and not op.as_view:
return IfElse(n_outs = op.n_outs, return IfElse(n_outs=op.n_outs,
as_view = True, as_view=True,
gpu = op.gpu, gpu=op.gpu,
name = op.name ).make_node(*node.inputs).outputs name=op.name).make_node(*node.inputs).outputs
return False return False
optdb.register('cond_make_inplace', opt.in2out(cond_make_inplace, optdb.register('cond_make_inplace', opt.in2out(cond_make_inplace,
...@@ -330,7 +318,8 @@ ifelse_equilibrium = gof.EquilibriumDB() ...@@ -330,7 +318,8 @@ ifelse_equilibrium = gof.EquilibriumDB()
ifelse_seqopt = gof.SequenceDB() ifelse_seqopt = gof.SequenceDB()
ifelse_equilibrium.register('seq_ifelse', ifelse_seqopt, 'fast_run', ifelse_equilibrium.register('seq_ifelse', ifelse_seqopt, 'fast_run',
'ifelse') 'ifelse')
optdb.register('ifelse_equilibriumOpt', ifelse_equilibrium, .5, 'fast_run', 'ifelse') optdb.register('ifelse_equilibriumOpt', ifelse_equilibrium, .5, 'fast_run',
'ifelse')
@gof.local_optimizer([None]) @gof.local_optimizer([None])
...@@ -341,18 +330,18 @@ def cond_merge_ifs_true(node): ...@@ -341,18 +330,18 @@ def cond_merge_ifs_true(node):
t_ins = node.inputs[1:][:op.n_outs] t_ins = node.inputs[1:][:op.n_outs]
replace = {} replace = {}
for idx,tval in enumerate(t_ins): for idx, tval in enumerate(t_ins):
if (tval.owner and isinstance(tval.owner.op, IfElse) and if (tval.owner and isinstance(tval.owner.op, IfElse) and
tval.owner.inputs[0] == node.inputs[0]): tval.owner.inputs[0] == node.inputs[0]):
ins_op = tval.owner.op ins_op = tval.owner.op
ins_t = tval.owner.inputs[1:][:ins_op.n_outs] ins_t = tval.owner.inputs[1:][:ins_op.n_outs]
replace[idx+1] = ins_t[tval.owner.outputs.index(tval)] replace[idx + 1] = ins_t[tval.owner.outputs.index(tval)]
if len(replace.items()) == 0: if len(replace.items()) == 0:
return False return False
old_ins = list(node.inputs) old_ins = list(node.inputs)
for pos,var in replace.items(): for pos, var in replace.items():
old_ins[pos] = var old_ins[pos] = var
return op.make_node(*old_ins).outputs return op.make_node(*old_ins).outputs
...@@ -365,24 +354,26 @@ def cond_merge_ifs_false(node): ...@@ -365,24 +354,26 @@ def cond_merge_ifs_false(node):
f_ins = node.inputs[1:][op.n_outs:] f_ins = node.inputs[1:][op.n_outs:]
replace = {} replace = {}
for idx,fval in enumerate(f_ins): for idx, fval in enumerate(f_ins):
if (fval.owner and isinstance(fval.owner.op, IfElse) and if (fval.owner and isinstance(fval.owner.op, IfElse) and
fval.owner.inputs[0] == node.inputs[0]): fval.owner.inputs[0] == node.inputs[0]):
ins_op = fval.owner.op ins_op = fval.owner.op
ins_t = fval.owner.inputs[1:][ins_op.n_outs:] ins_t = fval.owner.inputs[1:][ins_op.n_outs:]
replace[idx+1+op.n_outs] = ins_t[fval.owner.outputs.index(fval)] replace[idx + 1 + op.n_outs] = \
ins_t[fval.owner.outputs.index(fval)]
if len(replace.items()) == 0: if len(replace.items()) == 0:
return False return False
old_ins = list(node.inputs) old_ins = list(node.inputs)
for pos,var in replace.items(): for pos, var in replace.items():
old_ins[pos] = var old_ins[pos] = var
return op.make_node(*old_ins).outputs return op.make_node(*old_ins).outputs
class CondMerge(gof.Optimizer): class CondMerge(gof.Optimizer):
""" Graph Optimizer that merges different cond ops """ """ Graph Optimizer that merges different cond ops """
def add_requirements(self,env): def add_requirements(self, env):
env.extend(gof.toolbox.ReplaceValidate()) env.extend(gof.toolbox.ReplaceValidate())
def apply(self, env): def apply(self, env):
...@@ -400,7 +391,7 @@ class CondMerge(gof.Optimizer): ...@@ -400,7 +391,7 @@ class CondMerge(gof.Optimizer):
pl_ts = proposal.inputs[1:][:proposal.op.n_outs] pl_ts = proposal.inputs[1:][:proposal.op.n_outs]
pl_fs = proposal.inputs[1:][proposal.op.n_outs:] pl_fs = proposal.inputs[1:][proposal.op.n_outs:]
new_ins = ([merging_node.inputs[0]] + new_ins = ([merging_node.inputs[0]] +
mn_ts + pl_ts + mn_fs + pl_fs ) mn_ts + pl_ts + mn_fs + pl_fs)
mn_name = '?' mn_name = '?'
if merging_node.op.name: if merging_node.op.name:
mn_name = merging_node.op.name mn_name = merging_node.op.name
...@@ -410,10 +401,10 @@ class CondMerge(gof.Optimizer): ...@@ -410,10 +401,10 @@ class CondMerge(gof.Optimizer):
if proposal.op.name: if proposal.op.name:
pl_name = proposal.op.name pl_name = proposal.op.name
new_ifelse = IfElse( new_ifelse = IfElse(
n_outs = len(mn_ts+pl_ts), n_outs=len(mn_ts + pl_ts),
as_view=False, as_view=False,
gpu = False, gpu=False,
name = mn_name+'&'+pl_name) name=mn_name + '&' + pl_name)
print 'here' print 'here'
new_outs = new_ifelse.make_node(*new_ins).outputs new_outs = new_ifelse.make_node(*new_ins).outputs
new_outs = [clone(x) for x in new_outs] new_outs = [clone(x) for x in new_outs]
...@@ -443,7 +434,7 @@ def cond_remove_identical(node): ...@@ -443,7 +434,7 @@ def cond_remove_identical(node):
out_map = {} out_map = {}
for idx in xrange(len(node.outputs)): for idx in xrange(len(node.outputs)):
if idx not in out_map: if idx not in out_map:
for jdx in xrange(idx+1,len(node.outputs)): for jdx in xrange(idx + 1, len(node.outputs)):
if (ts[idx] == ts[jdx] and if (ts[idx] == ts[jdx] and
fs[idx] == fs[jdx] and fs[idx] == fs[jdx] and
jdx not in out_map): jdx not in out_map):
...@@ -463,12 +454,10 @@ def cond_remove_identical(node): ...@@ -463,12 +454,10 @@ def cond_remove_identical(node):
nw_ts.append(ts[idx]) nw_ts.append(ts[idx])
nw_fs.append(fs[idx]) nw_fs.append(fs[idx])
new_ifelse = IfElse(n_outs = len(nw_ts), new_ifelse = IfElse(n_outs=len(nw_ts),
as_view = op.as_view, as_view=op.as_view,
gpu = op.gpu, gpu=op.gpu,
name = op.name) name=op.name)
new_ins = [node.inputs[0]] + nw_ts + nw_fs new_ins = [node.inputs[0]] + nw_ts + nw_fs
new_outs = new_ifelse.make_node(*new_ins).outputs new_outs = new_ifelse.make_node(*new_ins).outputs
...@@ -476,14 +465,12 @@ def cond_remove_identical(node): ...@@ -476,14 +465,12 @@ def cond_remove_identical(node):
rval = [] rval = []
for idx in xrange(len(node.outputs)): for idx in xrange(len(node.outputs)):
if idx in out_map.keys(): if idx in out_map.keys():
rval += [new_outs[inv_map[out_map[idx]]] ] rval += [new_outs[inv_map[out_map[idx]]]]
else: else:
rval += [new_outs[inv_map[idx]]] rval += [new_outs[inv_map[idx]]]
return rval return rval
acceptable_ops = (theano.tensor.basic.Dot, acceptable_ops = (theano.tensor.basic.Dot,
theano.tensor.basic.Reshape, theano.tensor.basic.Reshape,
theano.tensor.basic.Shape, theano.tensor.basic.Shape,
...@@ -494,7 +481,7 @@ acceptable_ops = (theano.tensor.basic.Dot, ...@@ -494,7 +481,7 @@ acceptable_ops = (theano.tensor.basic.Dot,
theano.tensor.basic.Rebroadcast, theano.tensor.basic.Rebroadcast,
theano.tensor.basic.Alloc, theano.tensor.basic.Alloc,
theano.tensor.elemwise.Elemwise, theano.tensor.elemwise.Elemwise,
theano.tensor.elemwise.DimShuffle ) theano.tensor.elemwise.DimShuffle)
@gof.local_optimizer([None]) @gof.local_optimizer([None])
...@@ -504,7 +491,8 @@ def cond_lift_single_if(main_node): ...@@ -504,7 +491,8 @@ def cond_lift_single_if(main_node):
all_inp_nodes = set() all_inp_nodes = set()
for inp in main_node.inputs: for inp in main_node.inputs:
all_inp_nodes.add(inp.owner) all_inp_nodes.add(inp.owner)
ifnodes = [ x for x in list(all_inp_nodes) if x and isinstance(x.op, IfElse)] ifnodes = [x for x in list(all_inp_nodes)
if x and isinstance(x.op, IfElse)]
# if we have multiple ifs as inputs .. it all becomes quite complicated # if we have multiple ifs as inputs .. it all becomes quite complicated
# :) # :)
if len(ifnodes) != 1: if len(ifnodes) != 1:
...@@ -539,7 +527,6 @@ def cond_lift_single_if(main_node): ...@@ -539,7 +527,6 @@ def cond_lift_single_if(main_node):
return nw_outs return nw_outs
@gof.local_optimizer([None]) @gof.local_optimizer([None])
def cond_merge_random_op(main_node): def cond_merge_random_op(main_node):
if isinstance(main_node.op, IfElse): if isinstance(main_node.op, IfElse):
...@@ -548,7 +535,8 @@ def cond_merge_random_op(main_node): ...@@ -548,7 +535,8 @@ def cond_merge_random_op(main_node):
all_inp_nodes = set() all_inp_nodes = set()
for inp in main_node.inputs: for inp in main_node.inputs:
all_inp_nodes.add(inp.owner) all_inp_nodes.add(inp.owner)
cond_nodes = [ x for x in list(all_inp_nodes) if x and isinstance(x.op, IfElse)] cond_nodes = [x for x in list(all_inp_nodes)
if x and isinstance(x.op, IfElse)]
if len(cond_nodes) < 2: if len(cond_nodes) < 2:
return False return False
...@@ -564,7 +552,7 @@ def cond_merge_random_op(main_node): ...@@ -564,7 +552,7 @@ def cond_merge_random_op(main_node):
pl_ts = proposal.inputs[1:][:proposal.op.n_outs] pl_ts = proposal.inputs[1:][:proposal.op.n_outs]
pl_fs = proposal.inputs[1:][proposal.op.n_outs:] pl_fs = proposal.inputs[1:][proposal.op.n_outs:]
new_ins = ([merging_node.inputs[0]] + new_ins = ([merging_node.inputs[0]] +
mn_ts + pl_ts + mn_fs + pl_fs ) mn_ts + pl_ts + mn_fs + pl_fs)
mn_name = '?' mn_name = '?'
if merging_node.op.name: if merging_node.op.name:
mn_name = merging_node.op.name mn_name = merging_node.op.name
...@@ -574,10 +562,10 @@ def cond_merge_random_op(main_node): ...@@ -574,10 +562,10 @@ def cond_merge_random_op(main_node):
if proposal.op.name: if proposal.op.name:
pl_name = proposal.op.name pl_name = proposal.op.name
new_ifelse = IfElse( new_ifelse = IfElse(
n_outs = len(mn_ts+pl_ts), n_outs=len(mn_ts + pl_ts),
as_view=False, as_view=False,
gpu = False, gpu=False,
name = mn_name+'&'+pl_name) name=mn_name + '&' + pl_name)
new_outs = new_ifelse.make_node(*new_ins).outputs new_outs = new_ifelse.make_node(*new_ins).outputs
old_outs = [] old_outs = []
if type(merging_node.outputs) not in (list, tuple): if type(merging_node.outputs) not in (list, tuple):
...@@ -593,19 +581,16 @@ def cond_merge_random_op(main_node): ...@@ -593,19 +581,16 @@ def cond_merge_random_op(main_node):
return main_outs return main_outs
pushout_equilibrium = gof.EquilibriumDB() pushout_equilibrium = gof.EquilibriumDB()
pushout_equilibrium.register("ifelse_lift", pushout_equilibrium.register("ifelse_lift",
opt.in2out(cond_lift_single_if, opt.in2out(cond_lift_single_if,
ignore_newtrees = True), ignore_newtrees=True),
'fast_run', 'ifelse') 'fast_run', 'ifelse')
pushout_equilibrium.register("ifelse_merge_ifs", pushout_equilibrium.register("ifelse_merge_ifs",
opt.in2out(cond_merge_random_op, opt.in2out(cond_merge_random_op,
ignore_newtrees = True), ignore_newtrees=True),
'fast_run', 'ifelse') 'fast_run', 'ifelse')
...@@ -615,7 +600,7 @@ pushout_equilibrium.register("ifelse_merge_nodes", ...@@ -615,7 +600,7 @@ pushout_equilibrium.register("ifelse_merge_nodes",
pushout_equilibrium.register("ifelse_remove_identical_inside", pushout_equilibrium.register("ifelse_remove_identical_inside",
opt.in2out(cond_remove_identical, opt.in2out(cond_remove_identical,
ignore_newtrees = True), ignore_newtrees=True),
'fast_run', 'ifelse') 'fast_run', 'ifelse')
pushout_equilibrium.register('ifelse_sameCondTrue_inside', pushout_equilibrium.register('ifelse_sameCondTrue_inside',
...@@ -644,8 +629,6 @@ ifelse_seqopt.register('ifelse_sameCondTrue', ...@@ -644,8 +629,6 @@ ifelse_seqopt.register('ifelse_sameCondTrue',
3, 'fast_run', 'ifelse') 3, 'fast_run', 'ifelse')
ifelse_seqopt.register('ifelse_sameCondFalse', ifelse_seqopt.register('ifelse_sameCondFalse',
opt.in2out(cond_merge_ifs_false, opt.in2out(cond_merge_ifs_false,
ignore_newtrees=True), ignore_newtrees=True),
...@@ -654,6 +637,5 @@ ifelse_seqopt.register('ifelse_sameCondFalse', ...@@ -654,6 +637,5 @@ ifelse_seqopt.register('ifelse_sameCondFalse',
ifelse_seqopt.register('ifelse_removeIdenetical', ifelse_seqopt.register('ifelse_removeIdenetical',
opt.in2out(cond_remove_identical, opt.in2out(cond_remove_identical,
ignore_newtrees = True), ignore_newtrees=True),
7, 'fast_run', 'ifelse') 7, 'fast_run', 'ifelse')
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论