提交 1e8b0b00 authored 作者: Frederic's avatar Frederic

pep8

上级 e2e178e1
......@@ -4,11 +4,11 @@ This module provides utility functions for the Scan Op
See scan.py for details on scan
"""
__docformat__ = 'restructedtext en'
__authors__ = ( "Razvan Pascanu "
"Frederic Bastien "
"James Bergstra "
"Pascal Lamblin "
"Arnaud Bergeron")
__authors__ = ("Razvan Pascanu "
"Frederic Bastien "
"James Bergstra "
"Pascal Lamblin "
"Arnaud Bergeron")
__copyright__ = "(c) 2010, Universite de Montreal"
__contact__ = "Razvan Pascanu <r.pascanu@gmail>"
......@@ -16,7 +16,6 @@ import copy
import logging
import numpy
from theano import config
from theano.compile.pfunc import rebuild_collect_shared
from theano import gof
from theano import tensor, scalar
......@@ -30,7 +29,8 @@ import theano
# Logging function for sending warning or info
_logger = logging.getLogger('theano.scan_utils')
def safe_new(x, tag = ''):
def safe_new(x, tag=''):
"""
Internal function that constructs a new variable from x with the same
type, but with a different name ( old name + tag). This function is used
......@@ -81,7 +81,7 @@ class until(object):
assert self.condition.ndim == 0
def traverse(out, x,x_copy, d):
def traverse(out, x, x_copy, d):
''' Function used by scan to parse the tree and figure out which nodes
it needs to replace. There are two options :
1) x and x_copy or on host, then you would replace x with x_copy
......@@ -111,10 +111,10 @@ def traverse(out, x,x_copy, d):
def hash_listsDictsTuples(x):
hash_value = 0
if isinstance(x, dict):
for k,v in x.iteritems():
for k, v in x.iteritems():
hash_value ^= hash_listsDictsTuples(k)
hash_value ^= hash_listsDictsTuples(v)
elif isinstance(x, (list,tuple)):
elif isinstance(x, (list, tuple)):
for v in x:
hash_value ^= hash_listsDictsTuples(v)
else:
......@@ -122,10 +122,10 @@ def hash_listsDictsTuples(x):
return hash_value
def clone( output
, replace = None
, strict = True
, copy_inputs = True):
def clone(output,
replace=None,
strict=True,
copy_inputs=True):
"""
Function that allows replacing subgraphs of a computational
graph. It returns a copy of the initial subgraph with the corresponding
......@@ -140,17 +140,16 @@ def clone( output
replaced by what
"""
inps, outs, other_stuff = rebuild_collect_shared( output
, []
, replace
, []
, strict
, copy_inputs
)
inps, outs, other_stuff = rebuild_collect_shared(output,
[],
replace,
[],
strict,
copy_inputs
)
return outs
def get_updates_and_outputs(ls):
"""
This function tries to recognize the updates dictionary, the
......@@ -160,7 +159,7 @@ def get_updates_and_outputs(ls):
"""
def is_outputs(elem):
if (isinstance(elem, (list,tuple)) and
if (isinstance(elem, (list, tuple)) and
all([isinstance(x, theano.Variable) for x in elem])):
return True
if isinstance(elem, theano.Variable):
......@@ -172,7 +171,7 @@ def get_updates_and_outputs(ls):
return True
# Dictionaries can be given as lists of tuples
if (isinstance(elem, (list, tuple)) and
all([isinstance(x, (list,tuple)) and len(x) ==2
all([isinstance(x, (list, tuple)) and len(x) == 2
for x in elem])):
return True
return False
......@@ -204,13 +203,13 @@ def get_updates_and_outputs(ls):
if is_updates(ls[1]):
return (None, _list(ls[0]), dict(ls[1]))
elif is_condition(ls[1]):
return ( ls[1].condition, _list(ls[0]), {})
return (ls[1].condition, _list(ls[0]), {})
else:
raise ValueError(error_msg)
elif is_updates(ls[0]):
if is_outputs(ls[1]):
_logger.warning(deprication_msg)
return ( None, _list(ls[1]), dict(ls[0]) )
return (None, _list(ls[1]), dict(ls[0]))
elif is_condition(ls[1]):
return (ls[1].condition, [], dict(ls[0]))
else:
......@@ -251,7 +250,7 @@ def isNaN_or_Inf_or_None(x):
isStr = False
if not isNaN and not isInf:
try:
val = get_constant_value(x)
val = get_constant_value(x)
isInf = numpy.isinf(val)
isNaN = numpy.isnan(val)
except Exception:
......@@ -264,7 +263,7 @@ def isNaN_or_Inf_or_None(x):
return isNone or isNaN or isInf or isStr
def expand( tensor_var, size):
def expand(tensor_var, size):
'''
Transoforms the shape of a tensor from (d1, d2 ... ) to ( d1+size, d2, ..)
by adding 0s at the end of the tensor.
......@@ -272,13 +271,14 @@ def expand( tensor_var, size):
# Corner case that I might use in an optimization
if size == 0:
return tensor_var
shapes = [ tensor_var.shape[x] for x in xrange(tensor_var.ndim) ]
zeros_shape = [size+shapes[0]] + shapes[1:]
empty = tensor.zeros( zeros_shape
, dtype = tensor_var.dtype)
shapes = [tensor_var.shape[x] for x in xrange(tensor_var.ndim)]
zeros_shape = [size + shapes[0]] + shapes[1:]
empty = tensor.zeros(zeros_shape,
dtype=tensor_var.dtype)
return tensor.set_subtensor(empty[:shapes[0]], tensor_var)
def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True):
def equal_computations(xs, ys, in_xs=None, in_ys=None, strict=True):
'''
Checks if to theano graphs represent the same computations (with
equivalence of inputs defined by map). Inputs are always assumed
......@@ -289,8 +289,7 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True):
if in_ys is None:
in_ys = []
for x,y in zip(xs,ys):
for x, y in zip(xs, ys):
if x.owner and not y.owner:
return False
if y.owner and not x.owner:
......@@ -300,7 +299,7 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True):
return False
if len(in_xs) != len(in_ys):
return False
for _x,_y in zip(in_xs, in_ys):
for _x, _y in zip(in_xs, in_ys):
if _x.type != _y.type:
return False
......@@ -308,17 +307,17 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True):
nds_y = gof.graph.io_toposort(in_ys, ys)
if len(nds_x) != len(nds_y):
return False
common = set(zip(in_xs,in_ys))
common = set(zip(in_xs, in_ys))
n_nodes = len(nds_x)
cont = True
idx = 0
for dx,dy in zip(xs,ys):
for dx, dy in zip(xs, ys):
if not dx.owner or not dy.owner:
if dy.owner or dx.owner:
return False
elif (isinstance(dx, tensor.Constant) and
isinstance(dy, tensor.Constant)):
if not ( numpy.all(dx.data == dy.data) and
if not (numpy.all(dx.data == dy.data) and
dx.dtype == dy.dtype and
dx.data.shape == dy.data.shape):
return False
......@@ -329,7 +328,7 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True):
if dx.type != dy.type:
return False
else:
if (dx,dy) not in common:
if (dx, dy) not in common:
return False
while cont and idx < n_nodes:
......@@ -342,9 +341,9 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True):
elif len(nd_x.outputs) != len(nd_y.outputs):
cont = False
else:
for dx,dy in zip(nd_x.inputs, nd_y.inputs):
if (dx,dy) not in common:
if strict and dx!= dy:
for dx, dy in zip(nd_x.inputs, nd_y.inputs):
if (dx, dy) not in common:
if strict and dx != dy:
if (isinstance(dx, tensor.Constant) and
isinstance(dy, tensor.Constant)):
if not (numpy.all(dx.data == dy.data) and
......@@ -359,16 +358,13 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True):
cont = cont and (dx.type == dy.type)
if cont:
for dx,dy in zip(nd_x.outputs, nd_y.outputs):
common.add((dx,dy))
for dx, dy in zip(nd_x.outputs, nd_y.outputs):
common.add((dx, dy))
idx += 1
return cont
def infer_shape(outs, inputs, input_shapes):
'''
Compute the shape of the outputs given the shape of the inputs
......@@ -416,6 +412,7 @@ def infer_shape(outs, inputs, input_shapes):
ret.append(shape_feature.shape_of[o])
return ret
class Validator(object):
def __init__(self, valid=[], invalid=[], valid_equivalent={}):
'''
......@@ -494,35 +491,35 @@ def scan_can_remove_outs(op, out_idxs):
the first one with the indices of outs that can be removed, the second
with the outputs that can not be removed.
'''
non_removable = [ o for i,o in enumerate(op.outputs) if i not in
non_removable = [o for i, o in enumerate(op.outputs) if i not in
out_idxs]
required_inputs = gof.graph.inputs(non_removable)
out_ins = []
offset = op.n_seqs
offset = op.n_seqs
lim = op.n_mit_mot + op.n_mit_sot + op.n_sit_sot
for idx in range(lim):
n_ins = len(op.info['tap_array'][idx])
out_ins += [op.inputs[offset:offset+n_ins]]
offset += n_ins
out_ins += [ [] for k in xrange(op.n_nit_sot) ]
out_ins += [ [op.inputs[offset+k]] for k in xrange(op.n_shared_outs)]
n_ins = len(op.info['tap_array'][idx])
out_ins += [op.inputs[offset:offset + n_ins]]
offset += n_ins
out_ins += [[] for k in xrange(op.n_nit_sot)]
out_ins += [[op.inputs[offset + k]] for k in xrange(op.n_shared_outs)]
added = True
out_idxs_mask = [1 for idx in out_idxs]
while added:
added = False
for pos,idx in enumerate(out_idxs):
if ( out_idxs_mask[pos] and
numpy.any([x in required_inputs for x in out_ins[idx]]) ):
for pos, idx in enumerate(out_idxs):
if (out_idxs_mask[pos] and
numpy.any([x in required_inputs for x in out_ins[idx]])):
# This output is required ..
out_idxs_mask[pos] = 0
required_inputs += gof.graph.inputs([op.outputs[idx]])
added = True
required_outs = [x for i,x in enumerate(out_idxs)
required_outs = [x for i, x in enumerate(out_idxs)
if out_idxs_mask[i] == 0]
not_required = [x for i,x in enumerate(out_idxs) if out_idxs_mask[i]==1]
not_required = [x for i, x in enumerate(out_idxs) if out_idxs_mask[i] == 1]
return (required_outs, not_required)
......@@ -537,107 +534,107 @@ def compress_outs(op, not_required, inputs):
node inputs, and changing the dictionary.
'''
info = {}
info['tap_array'] = []
info['n_seqs'] = op.info['n_seqs']
info['n_mit_mot'] = 0
info['n_mit_mot_outs'] = 0
info['tap_array'] = []
info['n_seqs'] = op.info['n_seqs']
info['n_mit_mot'] = 0
info['n_mit_mot_outs'] = 0
info['mit_mot_out_slices'] = []
info['n_mit_sot'] = 0
info['n_sit_sot'] = 0
info['n_shared_outs'] = 0
info['n_nit_sot'] = 0
info['truncate_gradient'] = op.info['truncate_gradient']
info['name'] = op.info['name']
info['inplace'] = op.info['inplace']
info['gpu'] = op.info['gpu']
info['mode'] = op.info['mode']
info['as_while'] = op.info['as_while']
info['profile'] = op.info['profile']
op_inputs = op.inputs[:op.n_seqs]
op_outputs = []
info['n_mit_sot'] = 0
info['n_sit_sot'] = 0
info['n_shared_outs'] = 0
info['n_nit_sot'] = 0
info['truncate_gradient'] = op.info['truncate_gradient']
info['name'] = op.info['name']
info['inplace'] = op.info['inplace']
info['gpu'] = op.info['gpu']
info['mode'] = op.info['mode']
info['as_while'] = op.info['as_while']
info['profile'] = op.info['profile']
op_inputs = op.inputs[:op.n_seqs]
op_outputs = []
node_inputs = inputs[:op.n_seqs + 1]
map_old_new = {}
offset = 0
ni_offset = op.n_seqs+1
i_offset = op.n_seqs
o_offset = 0
curr_pos = 0
ni_offset = op.n_seqs + 1
i_offset = op.n_seqs
o_offset = 0
curr_pos = 0
for idx in xrange(op.info['n_mit_mot']):
if offset + idx not in not_required:
map_old_new[offset+idx] = curr_pos
map_old_new[offset + idx] = curr_pos
curr_pos += 1
info['n_mit_mot'] += 1
info['tap_array'] += [op.tap_array[offset+idx]]
info['mit_mot_out_slices'] += [op.mit_mot_out_slices[offset+idx]]
info['tap_array'] += [op.tap_array[offset + idx]]
info['mit_mot_out_slices'] += [op.mit_mot_out_slices[offset + idx]]
# input taps
for jdx in op.tap_array[offset+idx]:
for jdx in op.tap_array[offset + idx]:
op_inputs += [op.inputs[i_offset]]
i_offset += 1
# output taps
for jdx in op.mit_mot_out_slices[offset+idx]:
for jdx in op.mit_mot_out_slices[offset + idx]:
op_outputs += [op.outputs[o_offset]]
o_offset += 1
# node inputs
node_inputs += [inputs[ni_offset+idx]]
node_inputs += [inputs[ni_offset + idx]]
else:
o_offset += len(op.mit_mot_out_slices[offset+idx])
i_offset += len(op.tap_array[offset+idx])
o_offset += len(op.mit_mot_out_slices[offset + idx])
i_offset += len(op.tap_array[offset + idx])
info['n_mit_mot_outs'] = len(op_outputs)
offset += op.n_mit_mot
offset += op.n_mit_mot
ni_offset += op.n_mit_mot
for idx in xrange(op.info['n_mit_sot']):
if offset + idx not in not_required:
map_old_new[offset+idx] = curr_pos
map_old_new[offset + idx] = curr_pos
curr_pos += 1
info['n_mit_sot'] += 1
info['tap_array'] += [op.tap_array[offset+idx]]
info['tap_array'] += [op.tap_array[offset + idx]]
#input taps
for jdx in op.tap_array[offset+idx]:
for jdx in op.tap_array[offset + idx]:
op_inputs += [op.inputs[i_offset]]
i_offset += 1
#output taps
op_outputs += [op.outputs[o_offset]]
o_offset+=1
o_offset += 1
#node inputs
node_inputs += [inputs[ni_offset+idx]]
node_inputs += [inputs[ni_offset + idx]]
else:
o_offset+=1
i_offset+=len(op.tap_array[offset+idx])
o_offset += 1
i_offset += len(op.tap_array[offset + idx])
offset += op.n_mit_sot
offset += op.n_mit_sot
ni_offset += op.n_mit_sot
for idx in xrange(op.info['n_sit_sot']):
if offset + idx not in not_required:
map_old_new[offset+idx] = curr_pos
map_old_new[offset + idx] = curr_pos
curr_pos += 1
info['n_sit_sot'] += 1
info['tap_array'] += [op.tap_array[offset+idx]]
info['tap_array'] += [op.tap_array[offset + idx]]
#input taps
op_inputs += [op.inputs[i_offset]]
i_offset += 1
#output taps
op_outputs += [op.outputs[o_offset]]
o_offset+=1
o_offset += 1
#node inputs
node_inputs += [inputs[ni_offset+idx]]
node_inputs += [inputs[ni_offset + idx]]
else:
o_offset+=1
i_offset+=1
o_offset += 1
i_offset += 1
offset += op.n_sit_sot
offset += op.n_sit_sot
ni_offset += op.n_sit_sot
nit_sot_ins = []
for idx in xrange(op.info['n_nit_sot']):
if offset + idx not in not_required:
map_old_new[offset+idx] = curr_pos
map_old_new[offset + idx] = curr_pos
curr_pos += 1
info['n_nit_sot'] += 1
op_outputs += [op.outputs[o_offset]]
o_offset+=1
nit_sot_ins += [inputs[ni_offset+idx+op.n_shared_outs]]
o_offset += 1
nit_sot_ins += [inputs[ni_offset + idx + op.n_shared_outs]]
else:
o_offset += 1
......@@ -645,14 +642,14 @@ def compress_outs(op, not_required, inputs):
shared_ins = []
for idx in xrange(op.info['n_shared_outs']):
if offset + idx not in not_required:
map_old_new[offset+idx] = curr_pos
map_old_new[offset + idx] = curr_pos
curr_pos += 1
info['n_shared_outs'] += 1
op_outputs += [ op.outputs[o_offset]]
o_offset +=1
op_inputs += [ op.inputs[i_offset]]
op_outputs += [op.outputs[o_offset]]
o_offset += 1
op_inputs += [op.inputs[i_offset]]
i_offset += 1
shared_ins += [inputs[ni_offset+idx]]
shared_ins += [inputs[ni_offset + idx]]
else:
o_offset += 1
i_offset += 1
......@@ -660,14 +657,15 @@ def compress_outs(op, not_required, inputs):
node_inputs += nit_sot_ins
# other stuff
op_inputs += op.inputs[i_offset:]
node_inputs += inputs[ni_offset+op.n_shared_outs+op.n_nit_sot:]
node_inputs += inputs[ni_offset + op.n_shared_outs + op.n_nit_sot:]
if op.as_while:
op_outputs += [op.outputs[o_offset]]
map_old_new[o_offset] = len(op_outputs)-1
map_old_new[o_offset] = len(op_outputs) - 1
#map_old_new[len(op_outputs)-1] = o_offset
return (op_inputs, op_outputs, info, node_inputs, map_old_new)
def find_up(l_node, f_node):
r"""
Goes up in the graph and returns True if a node in nodes is found.
......@@ -676,11 +674,12 @@ def find_up(l_node, f_node):
l_outs = l_node.outputs
else:
l_outs = l_node
l_ins = gof.graph.inputs(l_outs)
l_ins = gof.graph.inputs(l_outs)
nodes = gof.graph.io_toposort(l_ins, l_outs)
return f_node in nodes
def reconstruct_graph(inputs, outputs, tag = None):
def reconstruct_graph(inputs, outputs, tag=None):
"""
Different interface to clone, that allows you to pass inputs.
Compared to clone, this method always replaces the inputs with
......@@ -689,7 +688,7 @@ def reconstruct_graph(inputs, outputs, tag = None):
"""
if tag is None:
tag = ''
nw_inputs = [safe_new(x,tag) for x in inputs]
nw_inputs = [safe_new(x, tag) for x in inputs]
givens = {}
for nw_x, x in zip(nw_inputs, inputs):
givens[x] = nw_x
......@@ -698,9 +697,10 @@ def reconstruct_graph(inputs, outputs, tag = None):
if isinstance(inp, theano.Constant):
givens[inp] = inp.clone()
nw_outputs = clone( outputs, replace=givens)
nw_outputs = clone(outputs, replace=givens)
return (nw_inputs, nw_outputs)
class scan_args(object):
"""Parses the inputs and outputs of scan in an easy to manipulate format"""
def __init__(self, outer_inputs, outer_outputs,
......@@ -712,14 +712,14 @@ class scan_args(object):
inner_outputs = rval[1][:-1]
else:
inner_outputs = rval[1]
inner_inputs = rval[0]
inner_inputs = rval[0]
p = 1
q = 0
n_seqs = info['n_seqs']
self.outer_in_seqs = outer_inputs[p:p+n_seqs]
self.inner_in_seqs = inner_inputs[q:q+n_seqs]
self.outer_in_seqs = outer_inputs[p:p + n_seqs]
self.inner_in_seqs = inner_inputs[q:q + n_seqs]
p += n_seqs
q += n_seqs
......@@ -727,46 +727,47 @@ class scan_args(object):
n_mit_sot = info['n_mit_sot']
self.mit_mot_in_slices = info['tap_array'][:n_mit_mot]
self.mit_sot_in_slices = info['tap_array'][n_mit_mot:n_mit_mot+n_mit_sot]
self.mit_sot_in_slices = info['tap_array'][
n_mit_mot:n_mit_mot + n_mit_sot]
n_mit_mot_ins = sum(len(s) for s in self.mit_mot_in_slices)
n_mit_sot_ins = sum(len(s) for s in self.mit_sot_in_slices)
iimm = inner_inputs[q:q+n_mit_mot_ins]
iimm = inner_inputs[q:q + n_mit_mot_ins]
self.inner_in_mit_mot = []
qq = 0
for sl in self.mit_mot_in_slices:
self.inner_in_mit_mot.append(iimm[qq:qq+len(sl)])
self.inner_in_mit_mot.append(iimm[qq:qq + len(sl)])
qq += len(sl)
q += n_mit_mot_ins
iims = inner_inputs[q:q+n_mit_sot_ins]
iims = inner_inputs[q:q + n_mit_sot_ins]
self.inner_in_mit_sot = []
qq = 0
for sl in self.mit_sot_in_slices:
self.inner_in_mit_sot.append(iims[qq:qq+len(sl)])
self.inner_in_mit_sot.append(iims[qq:qq + len(sl)])
qq += len(sl)
q += n_mit_sot_ins
self.outer_in_mit_mot = outer_inputs[p:p+n_mit_mot]
self.outer_in_mit_mot = outer_inputs[p:p + n_mit_mot]
p += n_mit_mot
self.outer_in_mit_sot = outer_inputs[p:p+n_mit_sot]
self.outer_in_mit_sot = outer_inputs[p:p + n_mit_sot]
p += n_mit_sot
n_sit_sot = info['n_sit_sot']
self.outer_in_sit_sot = outer_inputs[p:p+n_sit_sot]
self.inner_in_sit_sot = inner_inputs[q:q+n_sit_sot]
self.outer_in_sit_sot = outer_inputs[p:p + n_sit_sot]
self.inner_in_sit_sot = inner_inputs[q:q + n_sit_sot]
p += n_sit_sot
q += n_sit_sot
n_shared_outs = info['n_shared_outs']
self.outer_in_shared = outer_inputs[p:p+n_shared_outs]
self.inner_in_shared = inner_inputs[q:q+n_shared_outs]
self.outer_in_shared = outer_inputs[p:p + n_shared_outs]
self.inner_in_shared = inner_inputs[q:q + n_shared_outs]
p += n_shared_outs
q += n_shared_outs
n_nit_sot = info['n_nit_sot']
self.outer_in_nit_sot = outer_inputs[p:p+n_nit_sot]
self.outer_in_nit_sot = outer_inputs[p:p + n_nit_sot]
p += n_nit_sot
self.outer_in_non_seqs = outer_inputs[p:]
......@@ -778,40 +779,39 @@ class scan_args(object):
self.mit_mot_out_slices = info['mit_mot_out_slices']
n_mit_mot_outs = info['n_mit_mot_outs']
self.outer_out_mit_mot = outer_outputs[p:p+n_mit_mot]
iomm = inner_outputs[q:q+n_mit_mot_outs]
self.outer_out_mit_mot = outer_outputs[p:p + n_mit_mot]
iomm = inner_outputs[q:q + n_mit_mot_outs]
self.inner_out_mit_mot = []
qq = 0
for sl in self.mit_mot_out_slices:
self.inner_out_mit_mot.append(iomm[qq:qq+len(sl)])
self.inner_out_mit_mot.append(iomm[qq:qq + len(sl)])
qq += len(sl)
p += n_mit_mot
q += n_mit_mot_outs
self.outer_out_mit_sot = outer_outputs[p:p+n_mit_sot]
self.inner_out_mit_sot = inner_outputs[q:q+n_mit_sot]
self.outer_out_mit_sot = outer_outputs[p:p + n_mit_sot]
self.inner_out_mit_sot = inner_outputs[q:q + n_mit_sot]
p += n_mit_sot
q += n_mit_sot
self.outer_out_sit_sot = outer_outputs[p:p+n_sit_sot]
self.inner_out_sit_sot = inner_outputs[q:q+n_sit_sot]
self.outer_out_sit_sot = outer_outputs[p:p + n_sit_sot]
self.inner_out_sit_sot = inner_outputs[q:q + n_sit_sot]
p += n_sit_sot
q += n_sit_sot
self.outer_out_nit_sot = outer_outputs[p:p+n_nit_sot]
self.inner_out_nit_sot = inner_outputs[q:q+n_nit_sot]
self.outer_out_nit_sot = outer_outputs[p:p + n_nit_sot]
self.inner_out_nit_sot = inner_outputs[q:q + n_nit_sot]
p += n_nit_sot
q += n_nit_sot
self.outer_out_shared = outer_outputs[p:p+n_shared_outs]
self.inner_out_shared = inner_outputs[q:q+n_shared_outs]
self.outer_out_shared = outer_outputs[p:p + n_shared_outs]
self.inner_out_shared = inner_outputs[q:q + n_shared_outs]
p += n_shared_outs
q += n_shared_outs
self.other_info = dict()
for k in ('truncate_gradient', 'name', 'mode', 'inplace',
'gpu','as_while', 'profile'):
'gpu', 'as_while', 'profile'):
self.other_info[k] = info[k]
inner_inputs = property(lambda self: (self.inner_in_seqs +
......@@ -842,18 +842,19 @@ class scan_args(object):
self.outer_out_nit_sot +
self.outer_out_shared))
info = property(lambda self: dict(n_seqs=len(self.outer_in_seqs),
n_mit_mot=len(self.outer_in_mit_mot),
n_mit_sot=len(self.outer_in_mit_sot),
tap_array=(self.mit_mot_in_slices +
self.mit_sot_in_slices +
[[-1]] * len(self.inner_in_sit_sot)),
n_sit_sot=len(self.outer_in_sit_sot),
n_nit_sot=len(self.outer_in_nit_sot),
n_shared_outs=len(self.outer_in_shared),
n_mit_mot_outs=sum(len(s) for s in self.mit_mot_out_slices),
mit_mot_out_slices=self.mit_mot_out_slices,
**self.other_info))
info = property(lambda self: dict(
n_seqs=len(self.outer_in_seqs),
n_mit_mot=len(self.outer_in_mit_mot),
n_mit_sot=len(self.outer_in_mit_sot),
tap_array=(self.mit_mot_in_slices +
self.mit_sot_in_slices +
[[-1]] * len(self.inner_in_sit_sot)),
n_sit_sot=len(self.outer_in_sit_sot),
n_nit_sot=len(self.outer_in_nit_sot),
n_shared_outs=len(self.outer_in_shared),
n_mit_mot_outs=sum(len(s) for s in self.mit_mot_out_slices),
mit_mot_out_slices=self.mit_mot_out_slices,
**self.other_info))
def __copy__(self):
res = object.__new__(type(self))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论