提交 b3cf9269 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Merge pull request #177 from nouiz/important_fix

Important fix Everything looks good
......@@ -8,7 +8,7 @@ The elemwise fct are also used with scalar operation! So it can happen that ndim
import StringIO, sys
import numpy
from theano import Op, Type, Apply, Variable, Constant
from theano import tensor, scalar
from theano import tensor, scalar, gof
import logging, copy
_logger_name = 'theano.sandbox.cuda.elemwise'
......@@ -42,8 +42,12 @@ class NaiveAlgo(object):
:param scalar_op: the scalar operation to execute on each element.
:param sync: if True, will wait after the kernel launch and check for error call.
"""
if scalar_op.c_support_code_apply(node=None, nodename="nodename"):
raise SupportCodeError(scalar_op)
try:
code = scalar_op.c_support_code_apply(node=None, name="nodename")
if code:
raise SupportCodeError(scalar_op)
except gof.utils.MethodNotDefined:
pass
self.scalar_op = scalar_op
self.sync = sync
self.inplace_pattern = inplace_pattern
......
......@@ -2097,14 +2097,14 @@ class Composite(ScalarOp):
return ()
return tuple(rval)
def c_support_code_apply(self, node, nodename):
def c_support_code_apply(self, node, name):
rval = []
for subnode, subnodename in zip(self.env.toposort(), self.nodenames):
try:
rval.append(
subnode.op.c_support_code_apply(
subnode,
subnodename % dict(nodename=nodename)))
subnodename % dict(nodename=name)))
except gof.utils.MethodNotDefined:
pass
return "\n".join(rval)
......
......@@ -4,11 +4,11 @@ This module provides utility functions for the Scan Op
See scan.py for details on scan
"""
__docformat__ = 'restructedtext en'
__authors__ = ( "Razvan Pascanu "
"Frederic Bastien "
"James Bergstra "
"Pascal Lamblin "
"Arnaud Bergeron")
__authors__ = ("Razvan Pascanu "
"Frederic Bastien "
"James Bergstra "
"Pascal Lamblin "
"Arnaud Bergeron")
__copyright__ = "(c) 2010, Universite de Montreal"
__contact__ = "Razvan Pascanu <r.pascanu@gmail>"
......@@ -16,7 +16,6 @@ import copy
import logging
import numpy
from theano import config
from theano.compile.pfunc import rebuild_collect_shared
from theano import gof
from theano import tensor, scalar
......@@ -30,7 +29,8 @@ import theano
# Logging function for sending warning or info
_logger = logging.getLogger('theano.scan_utils')
def safe_new(x, tag = ''):
def safe_new(x, tag=''):
"""
Internal function that constructs a new variable from x with the same
type, but with a different name ( old name + tag). This function is used
......@@ -81,7 +81,7 @@ class until(object):
assert self.condition.ndim == 0
def traverse(out, x,x_copy, d):
def traverse(out, x, x_copy, d):
''' Function used by scan to parse the tree and figure out which nodes
it needs to replace. There are two options :
1) x and x_copy or on host, then you would replace x with x_copy
......@@ -111,10 +111,10 @@ def traverse(out, x,x_copy, d):
def hash_listsDictsTuples(x):
hash_value = 0
if isinstance(x, dict):
for k,v in x.iteritems():
for k, v in x.iteritems():
hash_value ^= hash_listsDictsTuples(k)
hash_value ^= hash_listsDictsTuples(v)
elif isinstance(x, (list,tuple)):
elif isinstance(x, (list, tuple)):
for v in x:
hash_value ^= hash_listsDictsTuples(v)
else:
......@@ -122,10 +122,10 @@ def hash_listsDictsTuples(x):
return hash_value
def clone( output
, replace = None
, strict = True
, copy_inputs = True):
def clone(output,
replace=None,
strict=True,
copy_inputs=True):
"""
Function that allows replacing subgraphs of a computational
graph. It returns a copy of the initial subgraph with the corresponding
......@@ -140,17 +140,16 @@ def clone( output
replaced by what
"""
inps, outs, other_stuff = rebuild_collect_shared( output
, []
, replace
, []
, strict
, copy_inputs
)
inps, outs, other_stuff = rebuild_collect_shared(output,
[],
replace,
[],
strict,
copy_inputs
)
return outs
def get_updates_and_outputs(ls):
"""
This function tries to recognize the updates dictionary, the
......@@ -160,7 +159,7 @@ def get_updates_and_outputs(ls):
"""
def is_outputs(elem):
if (isinstance(elem, (list,tuple)) and
if (isinstance(elem, (list, tuple)) and
all([isinstance(x, theano.Variable) for x in elem])):
return True
if isinstance(elem, theano.Variable):
......@@ -172,7 +171,7 @@ def get_updates_and_outputs(ls):
return True
# Dictionaries can be given as lists of tuples
if (isinstance(elem, (list, tuple)) and
all([isinstance(x, (list,tuple)) and len(x) ==2
all([isinstance(x, (list, tuple)) and len(x) == 2
for x in elem])):
return True
return False
......@@ -204,13 +203,13 @@ def get_updates_and_outputs(ls):
if is_updates(ls[1]):
return (None, _list(ls[0]), dict(ls[1]))
elif is_condition(ls[1]):
return ( ls[1].condition, _list(ls[0]), {})
return (ls[1].condition, _list(ls[0]), {})
else:
raise ValueError(error_msg)
elif is_updates(ls[0]):
if is_outputs(ls[1]):
_logger.warning(deprication_msg)
return ( None, _list(ls[1]), dict(ls[0]) )
return (None, _list(ls[1]), dict(ls[0]))
elif is_condition(ls[1]):
return (ls[1].condition, [], dict(ls[0]))
else:
......@@ -251,7 +250,7 @@ def isNaN_or_Inf_or_None(x):
isStr = False
if not isNaN and not isInf:
try:
val = get_constant_value(x)
val = get_constant_value(x)
isInf = numpy.isinf(val)
isNaN = numpy.isnan(val)
except Exception:
......@@ -264,7 +263,7 @@ def isNaN_or_Inf_or_None(x):
return isNone or isNaN or isInf or isStr
def expand( tensor_var, size):
def expand(tensor_var, size):
'''
Transoforms the shape of a tensor from (d1, d2 ... ) to ( d1+size, d2, ..)
by adding 0s at the end of the tensor.
......@@ -272,13 +271,14 @@ def expand( tensor_var, size):
# Corner case that I might use in an optimization
if size == 0:
return tensor_var
shapes = [ tensor_var.shape[x] for x in xrange(tensor_var.ndim) ]
zeros_shape = [size+shapes[0]] + shapes[1:]
empty = tensor.zeros( zeros_shape
, dtype = tensor_var.dtype)
shapes = [tensor_var.shape[x] for x in xrange(tensor_var.ndim)]
zeros_shape = [size + shapes[0]] + shapes[1:]
empty = tensor.zeros(zeros_shape,
dtype=tensor_var.dtype)
return tensor.set_subtensor(empty[:shapes[0]], tensor_var)
def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True):
def equal_computations(xs, ys, in_xs=None, in_ys=None, strict=True):
'''
Checks if to theano graphs represent the same computations (with
equivalence of inputs defined by map). Inputs are always assumed
......@@ -289,8 +289,7 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True):
if in_ys is None:
in_ys = []
for x,y in zip(xs,ys):
for x, y in zip(xs, ys):
if x.owner and not y.owner:
return False
if y.owner and not x.owner:
......@@ -300,7 +299,7 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True):
return False
if len(in_xs) != len(in_ys):
return False
for _x,_y in zip(in_xs, in_ys):
for _x, _y in zip(in_xs, in_ys):
if _x.type != _y.type:
return False
......@@ -308,17 +307,17 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True):
nds_y = gof.graph.io_toposort(in_ys, ys)
if len(nds_x) != len(nds_y):
return False
common = set(zip(in_xs,in_ys))
common = set(zip(in_xs, in_ys))
n_nodes = len(nds_x)
cont = True
idx = 0
for dx,dy in zip(xs,ys):
for dx, dy in zip(xs, ys):
if not dx.owner or not dy.owner:
if dy.owner or dx.owner:
return False
elif (isinstance(dx, tensor.Constant) and
isinstance(dy, tensor.Constant)):
if not ( numpy.all(dx.data == dy.data) and
if not (numpy.all(dx.data == dy.data) and
dx.dtype == dy.dtype and
dx.data.shape == dy.data.shape):
return False
......@@ -329,7 +328,7 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True):
if dx.type != dy.type:
return False
else:
if (dx,dy) not in common:
if (dx, dy) not in common:
return False
while cont and idx < n_nodes:
......@@ -342,9 +341,9 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True):
elif len(nd_x.outputs) != len(nd_y.outputs):
cont = False
else:
for dx,dy in zip(nd_x.inputs, nd_y.inputs):
if (dx,dy) not in common:
if strict and dx!= dy:
for dx, dy in zip(nd_x.inputs, nd_y.inputs):
if (dx, dy) not in common:
if strict and dx != dy:
if (isinstance(dx, tensor.Constant) and
isinstance(dy, tensor.Constant)):
if not (numpy.all(dx.data == dy.data) and
......@@ -359,32 +358,27 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True):
cont = cont and (dx.type == dy.type)
if cont:
for dx,dy in zip(nd_x.outputs, nd_y.outputs):
common.add((dx,dy))
for dx, dy in zip(nd_x.outputs, nd_y.outputs):
common.add((dx, dy))
idx += 1
return cont
def infer_shape(outs, inputs, input_shapes):
'''
Compute the shape of the outputs given the shape of the inputs
of a theano graph.
We do it this way to don't compile the inner function just to get
the shape. Change to ShapeFeature could request change in this function.
'''
# We use a ShapeFeature because it has all the necessary logic inside.
# We don't use the Feature interface, so we need to initialize some
# things by hand.
# We use a ShapeFeature because it has all the necessary logic
# inside. We don't use the full ShapeFeature interface, but we
# let it initialize itself with an empty env, otherwise we will
# need to do it manually
shape_feature = tensor.opt.ShapeFeature()
# Variable -> tuple(scalars) or None (All tensor vars map to tuple)
# All keys of shape_of should be either in valid or in invalid
shape_feature.shape_of = {}
# To avoid merging lots of ones together.
shape_feature.lscalar_one = tensor.constant(1, dtype='int64')
shape_feature.on_attach(theano.gof.Env([], []))
# Initialize shape_of with the input shapes
for inp, inp_shp in zip(inputs, input_shapes):
......@@ -418,6 +412,7 @@ def infer_shape(outs, inputs, input_shapes):
ret.append(shape_feature.shape_of[o])
return ret
class Validator(object):
def __init__(self, valid=[], invalid=[], valid_equivalent={}):
'''
......@@ -496,35 +491,35 @@ def scan_can_remove_outs(op, out_idxs):
the first one with the indices of outs that can be removed, the second
with the outputs that can not be removed.
'''
non_removable = [ o for i,o in enumerate(op.outputs) if i not in
non_removable = [o for i, o in enumerate(op.outputs) if i not in
out_idxs]
required_inputs = gof.graph.inputs(non_removable)
out_ins = []
offset = op.n_seqs
offset = op.n_seqs
lim = op.n_mit_mot + op.n_mit_sot + op.n_sit_sot
for idx in range(lim):
n_ins = len(op.info['tap_array'][idx])
out_ins += [op.inputs[offset:offset+n_ins]]
offset += n_ins
out_ins += [ [] for k in xrange(op.n_nit_sot) ]
out_ins += [ [op.inputs[offset+k]] for k in xrange(op.n_shared_outs)]
n_ins = len(op.info['tap_array'][idx])
out_ins += [op.inputs[offset:offset + n_ins]]
offset += n_ins
out_ins += [[] for k in xrange(op.n_nit_sot)]
out_ins += [[op.inputs[offset + k]] for k in xrange(op.n_shared_outs)]
added = True
out_idxs_mask = [1 for idx in out_idxs]
while added:
added = False
for pos,idx in enumerate(out_idxs):
if ( out_idxs_mask[pos] and
numpy.any([x in required_inputs for x in out_ins[idx]]) ):
for pos, idx in enumerate(out_idxs):
if (out_idxs_mask[pos] and
numpy.any([x in required_inputs for x in out_ins[idx]])):
# This output is required ..
out_idxs_mask[pos] = 0
required_inputs += gof.graph.inputs([op.outputs[idx]])
added = True
required_outs = [x for i,x in enumerate(out_idxs)
required_outs = [x for i, x in enumerate(out_idxs)
if out_idxs_mask[i] == 0]
not_required = [x for i,x in enumerate(out_idxs) if out_idxs_mask[i]==1]
not_required = [x for i, x in enumerate(out_idxs) if out_idxs_mask[i] == 1]
return (required_outs, not_required)
......@@ -539,107 +534,107 @@ def compress_outs(op, not_required, inputs):
node inputs, and changing the dictionary.
'''
info = {}
info['tap_array'] = []
info['n_seqs'] = op.info['n_seqs']
info['n_mit_mot'] = 0
info['n_mit_mot_outs'] = 0
info['tap_array'] = []
info['n_seqs'] = op.info['n_seqs']
info['n_mit_mot'] = 0
info['n_mit_mot_outs'] = 0
info['mit_mot_out_slices'] = []
info['n_mit_sot'] = 0
info['n_sit_sot'] = 0
info['n_shared_outs'] = 0
info['n_nit_sot'] = 0
info['truncate_gradient'] = op.info['truncate_gradient']
info['name'] = op.info['name']
info['inplace'] = op.info['inplace']
info['gpu'] = op.info['gpu']
info['mode'] = op.info['mode']
info['as_while'] = op.info['as_while']
info['profile'] = op.info['profile']
op_inputs = op.inputs[:op.n_seqs]
op_outputs = []
info['n_mit_sot'] = 0
info['n_sit_sot'] = 0
info['n_shared_outs'] = 0
info['n_nit_sot'] = 0
info['truncate_gradient'] = op.info['truncate_gradient']
info['name'] = op.info['name']
info['inplace'] = op.info['inplace']
info['gpu'] = op.info['gpu']
info['mode'] = op.info['mode']
info['as_while'] = op.info['as_while']
info['profile'] = op.info['profile']
op_inputs = op.inputs[:op.n_seqs]
op_outputs = []
node_inputs = inputs[:op.n_seqs + 1]
map_old_new = {}
offset = 0
ni_offset = op.n_seqs+1
i_offset = op.n_seqs
o_offset = 0
curr_pos = 0
ni_offset = op.n_seqs + 1
i_offset = op.n_seqs
o_offset = 0
curr_pos = 0
for idx in xrange(op.info['n_mit_mot']):
if offset + idx not in not_required:
map_old_new[offset+idx] = curr_pos
map_old_new[offset + idx] = curr_pos
curr_pos += 1
info['n_mit_mot'] += 1
info['tap_array'] += [op.tap_array[offset+idx]]
info['mit_mot_out_slices'] += [op.mit_mot_out_slices[offset+idx]]
info['tap_array'] += [op.tap_array[offset + idx]]
info['mit_mot_out_slices'] += [op.mit_mot_out_slices[offset + idx]]
# input taps
for jdx in op.tap_array[offset+idx]:
for jdx in op.tap_array[offset + idx]:
op_inputs += [op.inputs[i_offset]]
i_offset += 1
# output taps
for jdx in op.mit_mot_out_slices[offset+idx]:
for jdx in op.mit_mot_out_slices[offset + idx]:
op_outputs += [op.outputs[o_offset]]
o_offset += 1
# node inputs
node_inputs += [inputs[ni_offset+idx]]
node_inputs += [inputs[ni_offset + idx]]
else:
o_offset += len(op.mit_mot_out_slices[offset+idx])
i_offset += len(op.tap_array[offset+idx])
o_offset += len(op.mit_mot_out_slices[offset + idx])
i_offset += len(op.tap_array[offset + idx])
info['n_mit_mot_outs'] = len(op_outputs)
offset += op.n_mit_mot
offset += op.n_mit_mot
ni_offset += op.n_mit_mot
for idx in xrange(op.info['n_mit_sot']):
if offset + idx not in not_required:
map_old_new[offset+idx] = curr_pos
map_old_new[offset + idx] = curr_pos
curr_pos += 1
info['n_mit_sot'] += 1
info['tap_array'] += [op.tap_array[offset+idx]]
info['tap_array'] += [op.tap_array[offset + idx]]
#input taps
for jdx in op.tap_array[offset+idx]:
for jdx in op.tap_array[offset + idx]:
op_inputs += [op.inputs[i_offset]]
i_offset += 1
#output taps
op_outputs += [op.outputs[o_offset]]
o_offset+=1
o_offset += 1
#node inputs
node_inputs += [inputs[ni_offset+idx]]
node_inputs += [inputs[ni_offset + idx]]
else:
o_offset+=1
i_offset+=len(op.tap_array[offset+idx])
o_offset += 1
i_offset += len(op.tap_array[offset + idx])
offset += op.n_mit_sot
offset += op.n_mit_sot
ni_offset += op.n_mit_sot
for idx in xrange(op.info['n_sit_sot']):
if offset + idx not in not_required:
map_old_new[offset+idx] = curr_pos
map_old_new[offset + idx] = curr_pos
curr_pos += 1
info['n_sit_sot'] += 1
info['tap_array'] += [op.tap_array[offset+idx]]
info['tap_array'] += [op.tap_array[offset + idx]]
#input taps
op_inputs += [op.inputs[i_offset]]
i_offset += 1
#output taps
op_outputs += [op.outputs[o_offset]]
o_offset+=1
o_offset += 1
#node inputs
node_inputs += [inputs[ni_offset+idx]]
node_inputs += [inputs[ni_offset + idx]]
else:
o_offset+=1
i_offset+=1
o_offset += 1
i_offset += 1
offset += op.n_sit_sot
offset += op.n_sit_sot
ni_offset += op.n_sit_sot
nit_sot_ins = []
for idx in xrange(op.info['n_nit_sot']):
if offset + idx not in not_required:
map_old_new[offset+idx] = curr_pos
map_old_new[offset + idx] = curr_pos
curr_pos += 1
info['n_nit_sot'] += 1
op_outputs += [op.outputs[o_offset]]
o_offset+=1
nit_sot_ins += [inputs[ni_offset+idx+op.n_shared_outs]]
o_offset += 1
nit_sot_ins += [inputs[ni_offset + idx + op.n_shared_outs]]
else:
o_offset += 1
......@@ -647,14 +642,14 @@ def compress_outs(op, not_required, inputs):
shared_ins = []
for idx in xrange(op.info['n_shared_outs']):
if offset + idx not in not_required:
map_old_new[offset+idx] = curr_pos
map_old_new[offset + idx] = curr_pos
curr_pos += 1
info['n_shared_outs'] += 1
op_outputs += [ op.outputs[o_offset]]
o_offset +=1
op_inputs += [ op.inputs[i_offset]]
op_outputs += [op.outputs[o_offset]]
o_offset += 1
op_inputs += [op.inputs[i_offset]]
i_offset += 1
shared_ins += [inputs[ni_offset+idx]]
shared_ins += [inputs[ni_offset + idx]]
else:
o_offset += 1
i_offset += 1
......@@ -662,14 +657,15 @@ def compress_outs(op, not_required, inputs):
node_inputs += nit_sot_ins
# other stuff
op_inputs += op.inputs[i_offset:]
node_inputs += inputs[ni_offset+op.n_shared_outs+op.n_nit_sot:]
node_inputs += inputs[ni_offset + op.n_shared_outs + op.n_nit_sot:]
if op.as_while:
op_outputs += [op.outputs[o_offset]]
map_old_new[o_offset] = len(op_outputs)-1
map_old_new[o_offset] = len(op_outputs) - 1
#map_old_new[len(op_outputs)-1] = o_offset
return (op_inputs, op_outputs, info, node_inputs, map_old_new)
def find_up(l_node, f_node):
r"""
Goes up in the graph and returns True if a node in nodes is found.
......@@ -678,11 +674,12 @@ def find_up(l_node, f_node):
l_outs = l_node.outputs
else:
l_outs = l_node
l_ins = gof.graph.inputs(l_outs)
l_ins = gof.graph.inputs(l_outs)
nodes = gof.graph.io_toposort(l_ins, l_outs)
return f_node in nodes
def reconstruct_graph(inputs, outputs, tag = None):
def reconstruct_graph(inputs, outputs, tag=None):
"""
Different interface to clone, that allows you to pass inputs.
Compared to clone, this method always replaces the inputs with
......@@ -691,7 +688,7 @@ def reconstruct_graph(inputs, outputs, tag = None):
"""
if tag is None:
tag = ''
nw_inputs = [safe_new(x,tag) for x in inputs]
nw_inputs = [safe_new(x, tag) for x in inputs]
givens = {}
for nw_x, x in zip(nw_inputs, inputs):
givens[x] = nw_x
......@@ -700,9 +697,10 @@ def reconstruct_graph(inputs, outputs, tag = None):
if isinstance(inp, theano.Constant):
givens[inp] = inp.clone()
nw_outputs = clone( outputs, replace=givens)
nw_outputs = clone(outputs, replace=givens)
return (nw_inputs, nw_outputs)
class scan_args(object):
"""Parses the inputs and outputs of scan in an easy to manipulate format"""
def __init__(self, outer_inputs, outer_outputs,
......@@ -714,14 +712,14 @@ class scan_args(object):
inner_outputs = rval[1][:-1]
else:
inner_outputs = rval[1]
inner_inputs = rval[0]
inner_inputs = rval[0]
p = 1
q = 0
n_seqs = info['n_seqs']
self.outer_in_seqs = outer_inputs[p:p+n_seqs]
self.inner_in_seqs = inner_inputs[q:q+n_seqs]
self.outer_in_seqs = outer_inputs[p:p + n_seqs]
self.inner_in_seqs = inner_inputs[q:q + n_seqs]
p += n_seqs
q += n_seqs
......@@ -729,46 +727,47 @@ class scan_args(object):
n_mit_sot = info['n_mit_sot']
self.mit_mot_in_slices = info['tap_array'][:n_mit_mot]
self.mit_sot_in_slices = info['tap_array'][n_mit_mot:n_mit_mot+n_mit_sot]
self.mit_sot_in_slices = info['tap_array'][
n_mit_mot:n_mit_mot + n_mit_sot]
n_mit_mot_ins = sum(len(s) for s in self.mit_mot_in_slices)
n_mit_sot_ins = sum(len(s) for s in self.mit_sot_in_slices)
iimm = inner_inputs[q:q+n_mit_mot_ins]
iimm = inner_inputs[q:q + n_mit_mot_ins]
self.inner_in_mit_mot = []
qq = 0
for sl in self.mit_mot_in_slices:
self.inner_in_mit_mot.append(iimm[qq:qq+len(sl)])
self.inner_in_mit_mot.append(iimm[qq:qq + len(sl)])
qq += len(sl)
q += n_mit_mot_ins
iims = inner_inputs[q:q+n_mit_sot_ins]
iims = inner_inputs[q:q + n_mit_sot_ins]
self.inner_in_mit_sot = []
qq = 0
for sl in self.mit_sot_in_slices:
self.inner_in_mit_sot.append(iims[qq:qq+len(sl)])
self.inner_in_mit_sot.append(iims[qq:qq + len(sl)])
qq += len(sl)
q += n_mit_sot_ins
self.outer_in_mit_mot = outer_inputs[p:p+n_mit_mot]
self.outer_in_mit_mot = outer_inputs[p:p + n_mit_mot]
p += n_mit_mot
self.outer_in_mit_sot = outer_inputs[p:p+n_mit_sot]
self.outer_in_mit_sot = outer_inputs[p:p + n_mit_sot]
p += n_mit_sot
n_sit_sot = info['n_sit_sot']
self.outer_in_sit_sot = outer_inputs[p:p+n_sit_sot]
self.inner_in_sit_sot = inner_inputs[q:q+n_sit_sot]
self.outer_in_sit_sot = outer_inputs[p:p + n_sit_sot]
self.inner_in_sit_sot = inner_inputs[q:q + n_sit_sot]
p += n_sit_sot
q += n_sit_sot
n_shared_outs = info['n_shared_outs']
self.outer_in_shared = outer_inputs[p:p+n_shared_outs]
self.inner_in_shared = inner_inputs[q:q+n_shared_outs]
self.outer_in_shared = outer_inputs[p:p + n_shared_outs]
self.inner_in_shared = inner_inputs[q:q + n_shared_outs]
p += n_shared_outs
q += n_shared_outs
n_nit_sot = info['n_nit_sot']
self.outer_in_nit_sot = outer_inputs[p:p+n_nit_sot]
self.outer_in_nit_sot = outer_inputs[p:p + n_nit_sot]
p += n_nit_sot
self.outer_in_non_seqs = outer_inputs[p:]
......@@ -780,40 +779,39 @@ class scan_args(object):
self.mit_mot_out_slices = info['mit_mot_out_slices']
n_mit_mot_outs = info['n_mit_mot_outs']
self.outer_out_mit_mot = outer_outputs[p:p+n_mit_mot]
iomm = inner_outputs[q:q+n_mit_mot_outs]
self.outer_out_mit_mot = outer_outputs[p:p + n_mit_mot]
iomm = inner_outputs[q:q + n_mit_mot_outs]
self.inner_out_mit_mot = []
qq = 0
for sl in self.mit_mot_out_slices:
self.inner_out_mit_mot.append(iomm[qq:qq+len(sl)])
self.inner_out_mit_mot.append(iomm[qq:qq + len(sl)])
qq += len(sl)
p += n_mit_mot
q += n_mit_mot_outs
self.outer_out_mit_sot = outer_outputs[p:p+n_mit_sot]
self.inner_out_mit_sot = inner_outputs[q:q+n_mit_sot]
self.outer_out_mit_sot = outer_outputs[p:p + n_mit_sot]
self.inner_out_mit_sot = inner_outputs[q:q + n_mit_sot]
p += n_mit_sot
q += n_mit_sot
self.outer_out_sit_sot = outer_outputs[p:p+n_sit_sot]
self.inner_out_sit_sot = inner_outputs[q:q+n_sit_sot]
self.outer_out_sit_sot = outer_outputs[p:p + n_sit_sot]
self.inner_out_sit_sot = inner_outputs[q:q + n_sit_sot]
p += n_sit_sot
q += n_sit_sot
self.outer_out_nit_sot = outer_outputs[p:p+n_nit_sot]
self.inner_out_nit_sot = inner_outputs[q:q+n_nit_sot]
self.outer_out_nit_sot = outer_outputs[p:p + n_nit_sot]
self.inner_out_nit_sot = inner_outputs[q:q + n_nit_sot]
p += n_nit_sot
q += n_nit_sot
self.outer_out_shared = outer_outputs[p:p+n_shared_outs]
self.inner_out_shared = inner_outputs[q:q+n_shared_outs]
self.outer_out_shared = outer_outputs[p:p + n_shared_outs]
self.inner_out_shared = inner_outputs[q:q + n_shared_outs]
p += n_shared_outs
q += n_shared_outs
self.other_info = dict()
for k in ('truncate_gradient', 'name', 'mode', 'inplace',
'gpu','as_while', 'profile'):
'gpu', 'as_while', 'profile'):
self.other_info[k] = info[k]
inner_inputs = property(lambda self: (self.inner_in_seqs +
......@@ -844,18 +842,19 @@ class scan_args(object):
self.outer_out_nit_sot +
self.outer_out_shared))
info = property(lambda self: dict(n_seqs=len(self.outer_in_seqs),
n_mit_mot=len(self.outer_in_mit_mot),
n_mit_sot=len(self.outer_in_mit_sot),
tap_array=(self.mit_mot_in_slices +
self.mit_sot_in_slices +
[[-1]] * len(self.inner_in_sit_sot)),
n_sit_sot=len(self.outer_in_sit_sot),
n_nit_sot=len(self.outer_in_nit_sot),
n_shared_outs=len(self.outer_in_shared),
n_mit_mot_outs=sum(len(s) for s in self.mit_mot_out_slices),
mit_mot_out_slices=self.mit_mot_out_slices,
**self.other_info))
info = property(lambda self: dict(
n_seqs=len(self.outer_in_seqs),
n_mit_mot=len(self.outer_in_mit_mot),
n_mit_sot=len(self.outer_in_mit_sot),
tap_array=(self.mit_mot_in_slices +
self.mit_sot_in_slices +
[[-1]] * len(self.inner_in_sit_sot)),
n_sit_sot=len(self.outer_in_sit_sot),
n_nit_sot=len(self.outer_in_nit_sot),
n_shared_outs=len(self.outer_in_shared),
n_mit_mot_outs=sum(len(s) for s in self.mit_mot_out_slices),
mit_mot_out_slices=self.mit_mot_out_slices,
**self.other_info))
def __copy__(self):
res = object.__new__(type(self))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论