Commit ce690a69 authored by Razvan Pascanu, committed by David Warde-Farley

Made the code pep8 compatible

Parent 8c08c76e
"""
IfElse is an Op that works with the LazyLinker to support conditional graph evaluation.
IfElse is an Op that works with the LazyLinker to support conditional graph
evaluation.
:TODO: Add text to library documentation describing the IfElse Op.
"""
__docformat__ = 'restructedtext en'
__authors__ = ( "Razvan Pascanu "
"James Bergstra "
"Dumitru Erhan ")
__authors__ = ("Razvan Pascanu "
"James Bergstra "
"Dumitru Erhan ")
__copyright__ = "(c) 2010, Universite de Montreal"
__contact__ = "Razvan Pascanu <r.pascanu@gmail>"
......@@ -27,8 +28,6 @@ from scan_module.scan_utils import clone
# Module-level logger for the lazy-ifelse machinery; configure through the
# standard `logging` module under the 'theano.ifelse' name.
_logger = logging.getLogger('theano.ifelse')
class IfElse(PureOp):
"""
Op that works with CVM/VM to support conditional graph evaluation. Note,
......@@ -47,20 +46,19 @@ class IfElse(PureOp):
False branch before picking one.
"""
def __init__(self, n_outs, as_view=False, gpu = False, name = None):
def __init__(self, n_outs, as_view=False, gpu=False, name=None):
if as_view:
# check destroyhandler and others to ensure that a view_map with
# multiple inputs can work
view_map = {}
for idx in xrange(n_outs):
view_map[idx] = [idx+1]
view_map[idx] = [idx + 1]
self.view_map = view_map
#raise NotImplementedError('Cond must copy for now')
self.as_view = as_view
self.gpu = gpu
self.gpu = gpu
self.n_outs = n_outs
self.name = name
self.name = name
def __eq__(self, other):
if not type(self) == type(other):
......@@ -75,7 +73,7 @@ class IfElse(PureOp):
def __hash__(self):
rval = ( hash(type(self)) ^
rval = (hash(type(self)) ^
hash(self.as_view) ^
hash(self.gpu) ^
hash(self.n_outs))
......@@ -83,7 +81,7 @@ class IfElse(PureOp):
return rval
def __str__(self):
name ='if{%s'%str(self.name)
name = 'if{%s' % str(self.name)
if self.as_view:
name += ',inplace'
if self.gpu:
......@@ -91,7 +89,6 @@ class IfElse(PureOp):
name += '}'
return name
def infer_shape(self, node, inputs_shapes):
# By construction, corresponding then/else pairs have the same number
# of dimensions
......@@ -103,7 +100,7 @@ class IfElse(PureOp):
new_ts_inputs += list(ts_shape)
else:
# It can be None for generic objects
return [None]*self.n_outs
return [None] * self.n_outs
new_fs_inputs = []
for fs_shape in fs_shapes:
......@@ -111,18 +108,18 @@ class IfElse(PureOp):
new_fs_inputs += list(fs_shape)
else:
# It can be None for generic objects
return [None]*self.n_outs
return [None] * self.n_outs
assert len(new_ts_inputs) == len(new_fs_inputs)
if len(new_ts_inputs + new_fs_inputs) > 0:
new_ifelse = IfElse(
n_outs = len(new_ts_inputs),
as_view = False,
gpu = False,
name='shape_'+str(self.name))
n_outs=len(new_ts_inputs),
as_view=False,
gpu=False,
name='shape_' + str(self.name))
new_outs = new_ifelse.make_node(node.inputs[0],
*(new_ts_inputs+new_fs_inputs)).outputs
*(new_ts_inputs + new_fs_inputs)).outputs
else:
new_outs = []
......@@ -138,12 +135,8 @@ class IfElse(PureOp):
out_shapes += [tuple(current_shape)]
return out_shapes
def R_op(self, inputs, eval_points):
return self.make_node(inputs[0],*eval_points[1:]).outputs
return self.make_node(inputs[0], *eval_points[1:]).outputs
def make_node(self, c, *args):
if not self.gpu:
......@@ -160,15 +153,15 @@ class IfElse(PureOp):
ts = args[:self.n_outs]
fs = args[self.n_outs:]
for t,f in zip(ts, fs):
for t, f in zip(ts, fs):
if t.type != f.type:
raise TypeError(('IfElse requires same types for true and '
'false return values'), t, f, t.type, f.type)
if c.ndim >0:
if c.ndim > 0:
raise TypeError(('Condition given to the op has to be a scalar '
'with 0 standing for False, anything else for True'))
return Apply(self, [c]+list(args), [t.type() for t in ts])
'with 0 standing for False, anything else '
'for True'))
return Apply(self, [c] + list(args), [t.type() for t in ts])
def grad(self, ins, grads):
ts = ins[1:][:self.n_outs]
......@@ -179,35 +172,31 @@ class IfElse(PureOp):
else:
nw_name_t = None
nw_name_f = None
if_true_op = IfElse(n_outs = self.n_outs,
as_view = self.as_view,
gpu = self.gpu,
name = nw_name_t)
if_true_op = IfElse(n_outs=self.n_outs,
as_view=self.as_view,
gpu=self.gpu,
name=nw_name_t)
if_false_op = IfElse(n_outs = self.n_outs,
as_view = self.as_view,
gpu = self.gpu,
name = nw_name_f)
if_false_op = IfElse(n_outs=self.n_outs,
as_view=self.as_view,
gpu=self.gpu,
name=nw_name_f)
if_true = ([ins[0]]+ grads+ [theano.tensor.zeros_like(t)
if_true = ([ins[0]] + grads + [theano.tensor.zeros_like(t)
for t in ts])
if_false = ([ins[0]] + [theano.tensor.zeros_like(f)
for f in fs] + grads)
return ([None]+
return ([None] +
if_true_op.make_node(*if_true).outputs +
if_false_op.make_node(*if_false).outputs )
if_false_op.make_node(*if_false).outputs)
def make_thunk(self, node, storage_map, compute_map, no_recycling):
outtypes = [ out.type for out in node.outputs]
cond = node.inputs[0]
outtypes = [out.type for out in node.outputs]
cond = node.inputs[0]
ts = node.inputs[1:][:self.n_outs]
fs = node.inputs[1:][self.n_outs:]
outputs = node.outputs
outputs = node.outputs
def thunk():
if not compute_map[cond][0]:
......@@ -215,7 +204,7 @@ class IfElse(PureOp):
else:
truthval = storage_map[cond][0]
if truthval != 0:
ls = [idx+1 for idx in xrange(self.n_outs)
ls = [idx + 1 for idx in xrange(self.n_outs)
if not compute_map[ts[idx]][0]]
if len(ls) > 0:
return ls
......@@ -232,7 +221,7 @@ class IfElse(PureOp):
storage_map[out][0] = oval
return []
else:
ls = [1+idx+self.n_outs for idx in xrange(self.n_outs)
ls = [1 + idx + self.n_outs for idx in xrange(self.n_outs)
if not compute_map[fs[idx]][0]]
if len(ls) > 0:
return ls
......@@ -249,12 +238,12 @@ class IfElse(PureOp):
return []
thunk.lazy = True
thunk.inputs = [ storage_map[v] for v in node.inputs]
thunk.outputs = [ storage_map[v] for v in node.outputs]
thunk.inputs = [storage_map[v] for v in node.inputs]
thunk.outputs = [storage_map[v] for v in node.outputs]
return thunk
def ifelse( cond, true_branch, false_branch, name = None):
def ifelse(cond, true_branch, false_branch, name=None):
"""
This function corresponds to a if statement, returning inputs in the
``true_branch`` if ``cond`` evaluates to True or inputs in the
......@@ -293,34 +282,33 @@ def ifelse( cond, true_branch, false_branch, name = None):
false_branch = [false_branch]
if len(true_branch) != len(false_branch):
raise ValueError(( 'The number of values on the `then` branch'+
' should have the same number of variables as '+
'the `else` branch : (variables on `then` '+
'%d'%len(true_branch)+ ', variables on `else` '+
'%d'%len(false_branch)+')'))
raise ValueError(('The number of values on the `then` branch'
' should have the same number of variables as '
'the `else` branch : (variables on `then` '
'%d' % len(true_branch) + ', variables on `else` '
'%d' % len(false_branch) + ')'))
new_ifelse = IfElse(n_outs = len(true_branch),
new_ifelse = IfElse(n_outs=len(true_branch),
as_view=False,
gpu = False,
name = name)
gpu=False,
name=name)
ins = [cond] + list(true_branch) + list(false_branch)
rval = new_ifelse.make_node(*ins).outputs
if type(rval) in (list,tuple) and len(rval) == 1:
if type(rval) in (list, tuple) and len(rval) == 1:
return rval[0]
else:
return rval
@gof.local_optimizer([None])
def cond_make_inplace(node):
    """Local optimization: replace a copying IfElse node by an equivalent
    one whose outputs are views of the selected branch (``as_view=True``).

    Returns the replacement outputs, or False when the node is not a
    copying IfElse.
    """
    op = node.op
    if isinstance(op, IfElse) and not op.as_view:
        return IfElse(n_outs=op.n_outs,
                      as_view=True,
                      gpu=op.gpu,
                      name=op.name).make_node(*node.inputs).outputs
    return False
optdb.register('cond_make_inplace', opt.in2out(cond_make_inplace,
......@@ -330,29 +318,30 @@ ifelse_equilibrium = gof.EquilibriumDB()
ifelse_seqopt = gof.SequenceDB()
ifelse_equilibrium.register('seq_ifelse', ifelse_seqopt, 'fast_run',
'ifelse')
optdb.register('ifelse_equilibriumOpt', ifelse_equilibrium, .5, 'fast_run', 'ifelse')
optdb.register('ifelse_equilibriumOpt', ifelse_equilibrium, .5, 'fast_run',
'ifelse')
@gof.local_optimizer([None])
def cond_merge_ifs_true(node):
    """Local optimization: if a true-branch input of an IfElse is itself
    the output of another IfElse on the *same* condition, bypass the inner
    IfElse and use its true-branch input directly (the condition is known
    to be true on this path).

    Returns the replacement outputs, or False when nothing can be merged.
    """
    op = node.op
    if not isinstance(op, IfElse):
        return False
    t_ins = node.inputs[1:][:op.n_outs]

    # Map position in node.inputs -> replacement variable.
    replace = {}
    for idx, tval in enumerate(t_ins):
        if (tval.owner and isinstance(tval.owner.op, IfElse) and
            tval.owner.inputs[0] == node.inputs[0]):
            ins_op = tval.owner.op
            ins_t = tval.owner.inputs[1:][:ins_op.n_outs]
            # +1 skips the condition in node.inputs.
            replace[idx + 1] = ins_t[tval.owner.outputs.index(tval)]

    if len(replace.items()) == 0:
        return False

    old_ins = list(node.inputs)
    for pos, var in replace.items():
        old_ins[pos] = var
    return op.make_node(*old_ins).outputs
......@@ -365,24 +354,26 @@ def cond_merge_ifs_false(node):
f_ins = node.inputs[1:][op.n_outs:]
replace = {}
for idx,fval in enumerate(f_ins):
for idx, fval in enumerate(f_ins):
if (fval.owner and isinstance(fval.owner.op, IfElse) and
fval.owner.inputs[0] == node.inputs[0]):
ins_op = fval.owner.op
ins_t = fval.owner.inputs[1:][ins_op.n_outs:]
replace[idx+1+op.n_outs] = ins_t[fval.owner.outputs.index(fval)]
ins_t = fval.owner.inputs[1:][ins_op.n_outs:]
replace[idx + 1 + op.n_outs] = \
ins_t[fval.owner.outputs.index(fval)]
if len(replace.items()) == 0:
return False
old_ins = list(node.inputs)
for pos,var in replace.items():
for pos, var in replace.items():
old_ins[pos] = var
return op.make_node(*old_ins).outputs
class CondMerge(gof.Optimizer):
""" Graph Optimizer that merges different cond ops """
def add_requirements(self,env):
def add_requirements(self, env):
env.extend(gof.toolbox.ReplaceValidate())
def apply(self, env):
......@@ -400,20 +391,20 @@ class CondMerge(gof.Optimizer):
pl_ts = proposal.inputs[1:][:proposal.op.n_outs]
pl_fs = proposal.inputs[1:][proposal.op.n_outs:]
new_ins = ([merging_node.inputs[0]] +
mn_ts + pl_ts + mn_fs + pl_fs )
mn_ts + pl_ts + mn_fs + pl_fs)
mn_name = '?'
if merging_node.op.name:
mn_name = merging_node.op.name
pl_name = '?'
mn_n_ts = len(mn_ts)
mn_n_fs = len(mn_fs)
mn_n_ts = len(mn_ts)
mn_n_fs = len(mn_fs)
if proposal.op.name:
pl_name = proposal.op.name
new_ifelse = IfElse(
n_outs = len(mn_ts+pl_ts),
n_outs=len(mn_ts + pl_ts),
as_view=False,
gpu = False,
name = mn_name+'&'+pl_name)
gpu=False,
name=mn_name + '&' + pl_name)
print 'here'
new_outs = new_ifelse.make_node(*new_ins).outputs
new_outs = [clone(x) for x in new_outs]
......@@ -443,7 +434,7 @@ def cond_remove_identical(node):
out_map = {}
for idx in xrange(len(node.outputs)):
if idx not in out_map:
for jdx in xrange(idx+1,len(node.outputs)):
for jdx in xrange(idx + 1, len(node.outputs)):
if (ts[idx] == ts[jdx] and
fs[idx] == fs[jdx] and
jdx not in out_map):
......@@ -454,7 +445,7 @@ def cond_remove_identical(node):
nw_ts = []
nw_fs = []
inv_map = {}
inv_map = {}
pos = 0
for idx in xrange(len(node.outputs)):
if idx not in out_map:
......@@ -463,12 +454,10 @@ def cond_remove_identical(node):
nw_ts.append(ts[idx])
nw_fs.append(fs[idx])
new_ifelse = IfElse(n_outs = len(nw_ts),
as_view = op.as_view,
gpu = op.gpu,
name = op.name)
new_ifelse = IfElse(n_outs=len(nw_ts),
as_view=op.as_view,
gpu=op.gpu,
name=op.name)
new_ins = [node.inputs[0]] + nw_ts + nw_fs
new_outs = new_ifelse.make_node(*new_ins).outputs
......@@ -476,14 +465,12 @@ def cond_remove_identical(node):
rval = []
for idx in xrange(len(node.outputs)):
if idx in out_map.keys():
rval += [new_outs[inv_map[out_map[idx]]] ]
rval += [new_outs[inv_map[out_map[idx]]]]
else:
rval += [new_outs[inv_map[idx]]]
return rval
acceptable_ops = (theano.tensor.basic.Dot,
theano.tensor.basic.Reshape,
theano.tensor.basic.Shape,
......@@ -494,7 +481,7 @@ acceptable_ops = (theano.tensor.basic.Dot,
theano.tensor.basic.Rebroadcast,
theano.tensor.basic.Alloc,
theano.tensor.elemwise.Elemwise,
theano.tensor.elemwise.DimShuffle )
theano.tensor.elemwise.DimShuffle)
@gof.local_optimizer([None])
......@@ -504,7 +491,8 @@ def cond_lift_single_if(main_node):
all_inp_nodes = set()
for inp in main_node.inputs:
all_inp_nodes.add(inp.owner)
ifnodes = [ x for x in list(all_inp_nodes) if x and isinstance(x.op, IfElse)]
ifnodes = [x for x in list(all_inp_nodes)
if x and isinstance(x.op, IfElse)]
# if we have multiple ifs as inputs .. it all becomes quite complicated
# :)
if len(ifnodes) != 1:
......@@ -528,7 +516,7 @@ def cond_lift_single_if(main_node):
else:
true_ins.append(x)
false_ins.append(x)
true_eval = mop.make_node(*true_ins).outputs
true_eval = mop.make_node(*true_ins).outputs
false_eval = mop.make_node(*false_ins).outputs
#true_eval = clone(outs, replace = dict(zip(node.outputs, ts)))
#false_eval = clone(outs, replace = dict(zip(node.outputs, fs)))
......@@ -539,7 +527,6 @@ def cond_lift_single_if(main_node):
return nw_outs
@gof.local_optimizer([None])
def cond_merge_random_op(main_node):
if isinstance(main_node.op, IfElse):
......@@ -548,7 +535,8 @@ def cond_merge_random_op(main_node):
all_inp_nodes = set()
for inp in main_node.inputs:
all_inp_nodes.add(inp.owner)
cond_nodes = [ x for x in list(all_inp_nodes) if x and isinstance(x.op, IfElse)]
cond_nodes = [x for x in list(all_inp_nodes)
if x and isinstance(x.op, IfElse)]
if len(cond_nodes) < 2:
return False
......@@ -564,20 +552,20 @@ def cond_merge_random_op(main_node):
pl_ts = proposal.inputs[1:][:proposal.op.n_outs]
pl_fs = proposal.inputs[1:][proposal.op.n_outs:]
new_ins = ([merging_node.inputs[0]] +
mn_ts + pl_ts + mn_fs + pl_fs )
mn_ts + pl_ts + mn_fs + pl_fs)
mn_name = '?'
if merging_node.op.name:
mn_name = merging_node.op.name
pl_name = '?'
mn_n_ts = len(mn_ts)
mn_n_fs = len(mn_fs)
mn_n_ts = len(mn_ts)
mn_n_fs = len(mn_fs)
if proposal.op.name:
pl_name = proposal.op.name
new_ifelse = IfElse(
n_outs = len(mn_ts+pl_ts),
n_outs=len(mn_ts + pl_ts),
as_view=False,
gpu = False,
name = mn_name+'&'+pl_name)
gpu=False,
name=mn_name + '&' + pl_name)
new_outs = new_ifelse.make_node(*new_ins).outputs
old_outs = []
if type(merging_node.outputs) not in (list, tuple):
......@@ -593,19 +581,16 @@ def cond_merge_random_op(main_node):
return main_outs
# Equilibrium database grouping the lift/merge rewrites; they are applied
# until no further change occurs.
pushout_equilibrium = gof.EquilibriumDB()

pushout_equilibrium.register("ifelse_lift",
                             opt.in2out(cond_lift_single_if,
                                        ignore_newtrees=True),
                             'fast_run', 'ifelse')

pushout_equilibrium.register("ifelse_merge_ifs",
                             opt.in2out(cond_merge_random_op,
                                        ignore_newtrees=True),
                             'fast_run', 'ifelse')
......@@ -615,7 +600,7 @@ pushout_equilibrium.register("ifelse_merge_nodes",
pushout_equilibrium.register("ifelse_remove_identical_inside",
opt.in2out(cond_remove_identical,
ignore_newtrees = True),
ignore_newtrees=True),
'fast_run', 'ifelse')
pushout_equilibrium.register('ifelse_sameCondTrue_inside',
......@@ -644,8 +629,6 @@ ifelse_seqopt.register('ifelse_sameCondTrue',
3, 'fast_run', 'ifelse')
ifelse_seqopt.register('ifelse_sameCondFalse',
opt.in2out(cond_merge_ifs_false,
ignore_newtrees=True),
......@@ -654,6 +637,5 @@ ifelse_seqopt.register('ifelse_sameCondFalse',
# NOTE(review): 'removeIdenetical' is a historical typo, kept byte-for-byte
# because it is the registered optimization name.
ifelse_seqopt.register('ifelse_removeIdenetical',
                       opt.in2out(cond_remove_identical,
                                  ignore_newtrees=True),
                       7, 'fast_run', 'ifelse')
Markdown formatting is supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment