提交 b3cf9269 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Merge pull request #177 from nouiz/important_fix

Important fix Everything looks good
...@@ -8,7 +8,7 @@ The elemwise fct are also used with scalar operation! So it can happen that ndim ...@@ -8,7 +8,7 @@ The elemwise fct are also used with scalar operation! So it can happen that ndim
import StringIO, sys import StringIO, sys
import numpy import numpy
from theano import Op, Type, Apply, Variable, Constant from theano import Op, Type, Apply, Variable, Constant
from theano import tensor, scalar from theano import tensor, scalar, gof
import logging, copy import logging, copy
_logger_name = 'theano.sandbox.cuda.elemwise' _logger_name = 'theano.sandbox.cuda.elemwise'
...@@ -42,8 +42,12 @@ class NaiveAlgo(object): ...@@ -42,8 +42,12 @@ class NaiveAlgo(object):
:param scalar_op: the scalar operation to execute on each element. :param scalar_op: the scalar operation to execute on each element.
:param sync: if True, will wait after the kernel launch and check for error call. :param sync: if True, will wait after the kernel launch and check for error call.
""" """
if scalar_op.c_support_code_apply(node=None, nodename="nodename"): try:
code = scalar_op.c_support_code_apply(node=None, name="nodename")
if code:
raise SupportCodeError(scalar_op) raise SupportCodeError(scalar_op)
except gof.utils.MethodNotDefined:
pass
self.scalar_op = scalar_op self.scalar_op = scalar_op
self.sync = sync self.sync = sync
self.inplace_pattern = inplace_pattern self.inplace_pattern = inplace_pattern
......
...@@ -2097,14 +2097,14 @@ class Composite(ScalarOp): ...@@ -2097,14 +2097,14 @@ class Composite(ScalarOp):
return () return ()
return tuple(rval) return tuple(rval)
def c_support_code_apply(self, node, nodename): def c_support_code_apply(self, node, name):
rval = [] rval = []
for subnode, subnodename in zip(self.env.toposort(), self.nodenames): for subnode, subnodename in zip(self.env.toposort(), self.nodenames):
try: try:
rval.append( rval.append(
subnode.op.c_support_code_apply( subnode.op.c_support_code_apply(
subnode, subnode,
subnodename % dict(nodename=nodename))) subnodename % dict(nodename=name)))
except gof.utils.MethodNotDefined: except gof.utils.MethodNotDefined:
pass pass
return "\n".join(rval) return "\n".join(rval)
......
...@@ -4,7 +4,7 @@ This module provides utility functions for the Scan Op ...@@ -4,7 +4,7 @@ This module provides utility functions for the Scan Op
See scan.py for details on scan See scan.py for details on scan
""" """
__docformat__ = 'restructedtext en' __docformat__ = 'restructedtext en'
__authors__ = ( "Razvan Pascanu " __authors__ = ("Razvan Pascanu "
"Frederic Bastien " "Frederic Bastien "
"James Bergstra " "James Bergstra "
"Pascal Lamblin " "Pascal Lamblin "
...@@ -16,7 +16,6 @@ import copy ...@@ -16,7 +16,6 @@ import copy
import logging import logging
import numpy import numpy
from theano import config
from theano.compile.pfunc import rebuild_collect_shared from theano.compile.pfunc import rebuild_collect_shared
from theano import gof from theano import gof
from theano import tensor, scalar from theano import tensor, scalar
...@@ -30,7 +29,8 @@ import theano ...@@ -30,7 +29,8 @@ import theano
# Logging function for sending warning or info # Logging function for sending warning or info
_logger = logging.getLogger('theano.scan_utils') _logger = logging.getLogger('theano.scan_utils')
def safe_new(x, tag = ''):
def safe_new(x, tag=''):
""" """
Internal function that constructs a new variable from x with the same Internal function that constructs a new variable from x with the same
type, but with a different name ( old name + tag). This function is used type, but with a different name ( old name + tag). This function is used
...@@ -81,7 +81,7 @@ class until(object): ...@@ -81,7 +81,7 @@ class until(object):
assert self.condition.ndim == 0 assert self.condition.ndim == 0
def traverse(out, x,x_copy, d): def traverse(out, x, x_copy, d):
''' Function used by scan to parse the tree and figure out which nodes ''' Function used by scan to parse the tree and figure out which nodes
it needs to replace. There are two options : it needs to replace. There are two options :
1) x and x_copy or on host, then you would replace x with x_copy 1) x and x_copy or on host, then you would replace x with x_copy
...@@ -111,10 +111,10 @@ def traverse(out, x,x_copy, d): ...@@ -111,10 +111,10 @@ def traverse(out, x,x_copy, d):
def hash_listsDictsTuples(x): def hash_listsDictsTuples(x):
hash_value = 0 hash_value = 0
if isinstance(x, dict): if isinstance(x, dict):
for k,v in x.iteritems(): for k, v in x.iteritems():
hash_value ^= hash_listsDictsTuples(k) hash_value ^= hash_listsDictsTuples(k)
hash_value ^= hash_listsDictsTuples(v) hash_value ^= hash_listsDictsTuples(v)
elif isinstance(x, (list,tuple)): elif isinstance(x, (list, tuple)):
for v in x: for v in x:
hash_value ^= hash_listsDictsTuples(v) hash_value ^= hash_listsDictsTuples(v)
else: else:
...@@ -122,10 +122,10 @@ def hash_listsDictsTuples(x): ...@@ -122,10 +122,10 @@ def hash_listsDictsTuples(x):
return hash_value return hash_value
def clone( output def clone(output,
, replace = None replace=None,
, strict = True strict=True,
, copy_inputs = True): copy_inputs=True):
""" """
Function that allows replacing subgraphs of a computational Function that allows replacing subgraphs of a computational
graph. It returns a copy of the initial subgraph with the corresponding graph. It returns a copy of the initial subgraph with the corresponding
...@@ -140,17 +140,16 @@ def clone( output ...@@ -140,17 +140,16 @@ def clone( output
replaced by what replaced by what
""" """
inps, outs, other_stuff = rebuild_collect_shared( output inps, outs, other_stuff = rebuild_collect_shared(output,
, [] [],
, replace replace,
, [] [],
, strict strict,
, copy_inputs copy_inputs
) )
return outs return outs
def get_updates_and_outputs(ls): def get_updates_and_outputs(ls):
""" """
This function tries to recognize the updates dictionary, the This function tries to recognize the updates dictionary, the
...@@ -160,7 +159,7 @@ def get_updates_and_outputs(ls): ...@@ -160,7 +159,7 @@ def get_updates_and_outputs(ls):
""" """
def is_outputs(elem): def is_outputs(elem):
if (isinstance(elem, (list,tuple)) and if (isinstance(elem, (list, tuple)) and
all([isinstance(x, theano.Variable) for x in elem])): all([isinstance(x, theano.Variable) for x in elem])):
return True return True
if isinstance(elem, theano.Variable): if isinstance(elem, theano.Variable):
...@@ -172,7 +171,7 @@ def get_updates_and_outputs(ls): ...@@ -172,7 +171,7 @@ def get_updates_and_outputs(ls):
return True return True
# Dictionaries can be given as lists of tuples # Dictionaries can be given as lists of tuples
if (isinstance(elem, (list, tuple)) and if (isinstance(elem, (list, tuple)) and
all([isinstance(x, (list,tuple)) and len(x) ==2 all([isinstance(x, (list, tuple)) and len(x) == 2
for x in elem])): for x in elem])):
return True return True
return False return False
...@@ -204,13 +203,13 @@ def get_updates_and_outputs(ls): ...@@ -204,13 +203,13 @@ def get_updates_and_outputs(ls):
if is_updates(ls[1]): if is_updates(ls[1]):
return (None, _list(ls[0]), dict(ls[1])) return (None, _list(ls[0]), dict(ls[1]))
elif is_condition(ls[1]): elif is_condition(ls[1]):
return ( ls[1].condition, _list(ls[0]), {}) return (ls[1].condition, _list(ls[0]), {})
else: else:
raise ValueError(error_msg) raise ValueError(error_msg)
elif is_updates(ls[0]): elif is_updates(ls[0]):
if is_outputs(ls[1]): if is_outputs(ls[1]):
_logger.warning(deprication_msg) _logger.warning(deprication_msg)
return ( None, _list(ls[1]), dict(ls[0]) ) return (None, _list(ls[1]), dict(ls[0]))
elif is_condition(ls[1]): elif is_condition(ls[1]):
return (ls[1].condition, [], dict(ls[0])) return (ls[1].condition, [], dict(ls[0]))
else: else:
...@@ -264,7 +263,7 @@ def isNaN_or_Inf_or_None(x): ...@@ -264,7 +263,7 @@ def isNaN_or_Inf_or_None(x):
return isNone or isNaN or isInf or isStr return isNone or isNaN or isInf or isStr
def expand( tensor_var, size): def expand(tensor_var, size):
''' '''
Transoforms the shape of a tensor from (d1, d2 ... ) to ( d1+size, d2, ..) Transoforms the shape of a tensor from (d1, d2 ... ) to ( d1+size, d2, ..)
by adding 0s at the end of the tensor. by adding 0s at the end of the tensor.
...@@ -272,13 +271,14 @@ def expand( tensor_var, size): ...@@ -272,13 +271,14 @@ def expand( tensor_var, size):
# Corner case that I might use in an optimization # Corner case that I might use in an optimization
if size == 0: if size == 0:
return tensor_var return tensor_var
shapes = [ tensor_var.shape[x] for x in xrange(tensor_var.ndim) ] shapes = [tensor_var.shape[x] for x in xrange(tensor_var.ndim)]
zeros_shape = [size+shapes[0]] + shapes[1:] zeros_shape = [size + shapes[0]] + shapes[1:]
empty = tensor.zeros( zeros_shape empty = tensor.zeros(zeros_shape,
, dtype = tensor_var.dtype) dtype=tensor_var.dtype)
return tensor.set_subtensor(empty[:shapes[0]], tensor_var) return tensor.set_subtensor(empty[:shapes[0]], tensor_var)
def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True):
def equal_computations(xs, ys, in_xs=None, in_ys=None, strict=True):
''' '''
Checks if to theano graphs represent the same computations (with Checks if to theano graphs represent the same computations (with
equivalence of inputs defined by map). Inputs are always assumed equivalence of inputs defined by map). Inputs are always assumed
...@@ -289,8 +289,7 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True): ...@@ -289,8 +289,7 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True):
if in_ys is None: if in_ys is None:
in_ys = [] in_ys = []
for x, y in zip(xs, ys):
for x,y in zip(xs,ys):
if x.owner and not y.owner: if x.owner and not y.owner:
return False return False
if y.owner and not x.owner: if y.owner and not x.owner:
...@@ -300,7 +299,7 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True): ...@@ -300,7 +299,7 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True):
return False return False
if len(in_xs) != len(in_ys): if len(in_xs) != len(in_ys):
return False return False
for _x,_y in zip(in_xs, in_ys): for _x, _y in zip(in_xs, in_ys):
if _x.type != _y.type: if _x.type != _y.type:
return False return False
...@@ -308,17 +307,17 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True): ...@@ -308,17 +307,17 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True):
nds_y = gof.graph.io_toposort(in_ys, ys) nds_y = gof.graph.io_toposort(in_ys, ys)
if len(nds_x) != len(nds_y): if len(nds_x) != len(nds_y):
return False return False
common = set(zip(in_xs,in_ys)) common = set(zip(in_xs, in_ys))
n_nodes = len(nds_x) n_nodes = len(nds_x)
cont = True cont = True
idx = 0 idx = 0
for dx,dy in zip(xs,ys): for dx, dy in zip(xs, ys):
if not dx.owner or not dy.owner: if not dx.owner or not dy.owner:
if dy.owner or dx.owner: if dy.owner or dx.owner:
return False return False
elif (isinstance(dx, tensor.Constant) and elif (isinstance(dx, tensor.Constant) and
isinstance(dy, tensor.Constant)): isinstance(dy, tensor.Constant)):
if not ( numpy.all(dx.data == dy.data) and if not (numpy.all(dx.data == dy.data) and
dx.dtype == dy.dtype and dx.dtype == dy.dtype and
dx.data.shape == dy.data.shape): dx.data.shape == dy.data.shape):
return False return False
...@@ -329,7 +328,7 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True): ...@@ -329,7 +328,7 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True):
if dx.type != dy.type: if dx.type != dy.type:
return False return False
else: else:
if (dx,dy) not in common: if (dx, dy) not in common:
return False return False
while cont and idx < n_nodes: while cont and idx < n_nodes:
...@@ -342,9 +341,9 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True): ...@@ -342,9 +341,9 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True):
elif len(nd_x.outputs) != len(nd_y.outputs): elif len(nd_x.outputs) != len(nd_y.outputs):
cont = False cont = False
else: else:
for dx,dy in zip(nd_x.inputs, nd_y.inputs): for dx, dy in zip(nd_x.inputs, nd_y.inputs):
if (dx,dy) not in common: if (dx, dy) not in common:
if strict and dx!= dy: if strict and dx != dy:
if (isinstance(dx, tensor.Constant) and if (isinstance(dx, tensor.Constant) and
isinstance(dy, tensor.Constant)): isinstance(dy, tensor.Constant)):
if not (numpy.all(dx.data == dy.data) and if not (numpy.all(dx.data == dy.data) and
...@@ -359,32 +358,27 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True): ...@@ -359,32 +358,27 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True):
cont = cont and (dx.type == dy.type) cont = cont and (dx.type == dy.type)
if cont: if cont:
for dx,dy in zip(nd_x.outputs, nd_y.outputs): for dx, dy in zip(nd_x.outputs, nd_y.outputs):
common.add((dx,dy)) common.add((dx, dy))
idx += 1 idx += 1
return cont return cont
def infer_shape(outs, inputs, input_shapes): def infer_shape(outs, inputs, input_shapes):
''' '''
Compute the shape of the outputs given the shape of the inputs Compute the shape of the outputs given the shape of the inputs
of a theano graph. of a theano graph.
We do it this way to don't compile the inner function just to get
the shape. Change to ShapeFeature could request change in this function.
''' '''
# We use a ShapeFeature because it has all the necessary logic inside. # We use a ShapeFeature because it has all the necessary logic
# We don't use the Feature interface, so we need to initialize some # inside. We don't use the full ShapeFeature interface, but we
# things by hand. # let it initialize itself with an empty env, otherwise we will
# need to do it manually
shape_feature = tensor.opt.ShapeFeature() shape_feature = tensor.opt.ShapeFeature()
shape_feature.on_attach(theano.gof.Env([], []))
# Variable -> tuple(scalars) or None (All tensor vars map to tuple)
# All keys of shape_of should be either in valid or in invalid
shape_feature.shape_of = {}
# To avoid merging lots of ones together.
shape_feature.lscalar_one = tensor.constant(1, dtype='int64')
# Initialize shape_of with the input shapes # Initialize shape_of with the input shapes
for inp, inp_shp in zip(inputs, input_shapes): for inp, inp_shp in zip(inputs, input_shapes):
...@@ -418,6 +412,7 @@ def infer_shape(outs, inputs, input_shapes): ...@@ -418,6 +412,7 @@ def infer_shape(outs, inputs, input_shapes):
ret.append(shape_feature.shape_of[o]) ret.append(shape_feature.shape_of[o])
return ret return ret
class Validator(object): class Validator(object):
def __init__(self, valid=[], invalid=[], valid_equivalent={}): def __init__(self, valid=[], invalid=[], valid_equivalent={}):
''' '''
...@@ -496,7 +491,7 @@ def scan_can_remove_outs(op, out_idxs): ...@@ -496,7 +491,7 @@ def scan_can_remove_outs(op, out_idxs):
the first one with the indices of outs that can be removed, the second the first one with the indices of outs that can be removed, the second
with the outputs that can not be removed. with the outputs that can not be removed.
''' '''
non_removable = [ o for i,o in enumerate(op.outputs) if i not in non_removable = [o for i, o in enumerate(op.outputs) if i not in
out_idxs] out_idxs]
required_inputs = gof.graph.inputs(non_removable) required_inputs = gof.graph.inputs(non_removable)
...@@ -505,26 +500,26 @@ def scan_can_remove_outs(op, out_idxs): ...@@ -505,26 +500,26 @@ def scan_can_remove_outs(op, out_idxs):
lim = op.n_mit_mot + op.n_mit_sot + op.n_sit_sot lim = op.n_mit_mot + op.n_mit_sot + op.n_sit_sot
for idx in range(lim): for idx in range(lim):
n_ins = len(op.info['tap_array'][idx]) n_ins = len(op.info['tap_array'][idx])
out_ins += [op.inputs[offset:offset+n_ins]] out_ins += [op.inputs[offset:offset + n_ins]]
offset += n_ins offset += n_ins
out_ins += [ [] for k in xrange(op.n_nit_sot) ] out_ins += [[] for k in xrange(op.n_nit_sot)]
out_ins += [ [op.inputs[offset+k]] for k in xrange(op.n_shared_outs)] out_ins += [[op.inputs[offset + k]] for k in xrange(op.n_shared_outs)]
added = True added = True
out_idxs_mask = [1 for idx in out_idxs] out_idxs_mask = [1 for idx in out_idxs]
while added: while added:
added = False added = False
for pos,idx in enumerate(out_idxs): for pos, idx in enumerate(out_idxs):
if ( out_idxs_mask[pos] and if (out_idxs_mask[pos] and
numpy.any([x in required_inputs for x in out_ins[idx]]) ): numpy.any([x in required_inputs for x in out_ins[idx]])):
# This output is required .. # This output is required ..
out_idxs_mask[pos] = 0 out_idxs_mask[pos] = 0
required_inputs += gof.graph.inputs([op.outputs[idx]]) required_inputs += gof.graph.inputs([op.outputs[idx]])
added = True added = True
required_outs = [x for i,x in enumerate(out_idxs) required_outs = [x for i, x in enumerate(out_idxs)
if out_idxs_mask[i] == 0] if out_idxs_mask[i] == 0]
not_required = [x for i,x in enumerate(out_idxs) if out_idxs_mask[i]==1] not_required = [x for i, x in enumerate(out_idxs) if out_idxs_mask[i] == 1]
return (required_outs, not_required) return (required_outs, not_required)
...@@ -562,84 +557,84 @@ def compress_outs(op, not_required, inputs): ...@@ -562,84 +557,84 @@ def compress_outs(op, not_required, inputs):
map_old_new = {} map_old_new = {}
offset = 0 offset = 0
ni_offset = op.n_seqs+1 ni_offset = op.n_seqs + 1
i_offset = op.n_seqs i_offset = op.n_seqs
o_offset = 0 o_offset = 0
curr_pos = 0 curr_pos = 0
for idx in xrange(op.info['n_mit_mot']): for idx in xrange(op.info['n_mit_mot']):
if offset + idx not in not_required: if offset + idx not in not_required:
map_old_new[offset+idx] = curr_pos map_old_new[offset + idx] = curr_pos
curr_pos += 1 curr_pos += 1
info['n_mit_mot'] += 1 info['n_mit_mot'] += 1
info['tap_array'] += [op.tap_array[offset+idx]] info['tap_array'] += [op.tap_array[offset + idx]]
info['mit_mot_out_slices'] += [op.mit_mot_out_slices[offset+idx]] info['mit_mot_out_slices'] += [op.mit_mot_out_slices[offset + idx]]
# input taps # input taps
for jdx in op.tap_array[offset+idx]: for jdx in op.tap_array[offset + idx]:
op_inputs += [op.inputs[i_offset]] op_inputs += [op.inputs[i_offset]]
i_offset += 1 i_offset += 1
# output taps # output taps
for jdx in op.mit_mot_out_slices[offset+idx]: for jdx in op.mit_mot_out_slices[offset + idx]:
op_outputs += [op.outputs[o_offset]] op_outputs += [op.outputs[o_offset]]
o_offset += 1 o_offset += 1
# node inputs # node inputs
node_inputs += [inputs[ni_offset+idx]] node_inputs += [inputs[ni_offset + idx]]
else: else:
o_offset += len(op.mit_mot_out_slices[offset+idx]) o_offset += len(op.mit_mot_out_slices[offset + idx])
i_offset += len(op.tap_array[offset+idx]) i_offset += len(op.tap_array[offset + idx])
info['n_mit_mot_outs'] = len(op_outputs) info['n_mit_mot_outs'] = len(op_outputs)
offset += op.n_mit_mot offset += op.n_mit_mot
ni_offset += op.n_mit_mot ni_offset += op.n_mit_mot
for idx in xrange(op.info['n_mit_sot']): for idx in xrange(op.info['n_mit_sot']):
if offset + idx not in not_required: if offset + idx not in not_required:
map_old_new[offset+idx] = curr_pos map_old_new[offset + idx] = curr_pos
curr_pos += 1 curr_pos += 1
info['n_mit_sot'] += 1 info['n_mit_sot'] += 1
info['tap_array'] += [op.tap_array[offset+idx]] info['tap_array'] += [op.tap_array[offset + idx]]
#input taps #input taps
for jdx in op.tap_array[offset+idx]: for jdx in op.tap_array[offset + idx]:
op_inputs += [op.inputs[i_offset]] op_inputs += [op.inputs[i_offset]]
i_offset += 1 i_offset += 1
#output taps #output taps
op_outputs += [op.outputs[o_offset]] op_outputs += [op.outputs[o_offset]]
o_offset+=1 o_offset += 1
#node inputs #node inputs
node_inputs += [inputs[ni_offset+idx]] node_inputs += [inputs[ni_offset + idx]]
else: else:
o_offset+=1 o_offset += 1
i_offset+=len(op.tap_array[offset+idx]) i_offset += len(op.tap_array[offset + idx])
offset += op.n_mit_sot offset += op.n_mit_sot
ni_offset += op.n_mit_sot ni_offset += op.n_mit_sot
for idx in xrange(op.info['n_sit_sot']): for idx in xrange(op.info['n_sit_sot']):
if offset + idx not in not_required: if offset + idx not in not_required:
map_old_new[offset+idx] = curr_pos map_old_new[offset + idx] = curr_pos
curr_pos += 1 curr_pos += 1
info['n_sit_sot'] += 1 info['n_sit_sot'] += 1
info['tap_array'] += [op.tap_array[offset+idx]] info['tap_array'] += [op.tap_array[offset + idx]]
#input taps #input taps
op_inputs += [op.inputs[i_offset]] op_inputs += [op.inputs[i_offset]]
i_offset += 1 i_offset += 1
#output taps #output taps
op_outputs += [op.outputs[o_offset]] op_outputs += [op.outputs[o_offset]]
o_offset+=1 o_offset += 1
#node inputs #node inputs
node_inputs += [inputs[ni_offset+idx]] node_inputs += [inputs[ni_offset + idx]]
else: else:
o_offset+=1 o_offset += 1
i_offset+=1 i_offset += 1
offset += op.n_sit_sot offset += op.n_sit_sot
ni_offset += op.n_sit_sot ni_offset += op.n_sit_sot
nit_sot_ins = [] nit_sot_ins = []
for idx in xrange(op.info['n_nit_sot']): for idx in xrange(op.info['n_nit_sot']):
if offset + idx not in not_required: if offset + idx not in not_required:
map_old_new[offset+idx] = curr_pos map_old_new[offset + idx] = curr_pos
curr_pos += 1 curr_pos += 1
info['n_nit_sot'] += 1 info['n_nit_sot'] += 1
op_outputs += [op.outputs[o_offset]] op_outputs += [op.outputs[o_offset]]
o_offset+=1 o_offset += 1
nit_sot_ins += [inputs[ni_offset+idx+op.n_shared_outs]] nit_sot_ins += [inputs[ni_offset + idx + op.n_shared_outs]]
else: else:
o_offset += 1 o_offset += 1
...@@ -647,14 +642,14 @@ def compress_outs(op, not_required, inputs): ...@@ -647,14 +642,14 @@ def compress_outs(op, not_required, inputs):
shared_ins = [] shared_ins = []
for idx in xrange(op.info['n_shared_outs']): for idx in xrange(op.info['n_shared_outs']):
if offset + idx not in not_required: if offset + idx not in not_required:
map_old_new[offset+idx] = curr_pos map_old_new[offset + idx] = curr_pos
curr_pos += 1 curr_pos += 1
info['n_shared_outs'] += 1 info['n_shared_outs'] += 1
op_outputs += [ op.outputs[o_offset]] op_outputs += [op.outputs[o_offset]]
o_offset +=1 o_offset += 1
op_inputs += [ op.inputs[i_offset]] op_inputs += [op.inputs[i_offset]]
i_offset += 1 i_offset += 1
shared_ins += [inputs[ni_offset+idx]] shared_ins += [inputs[ni_offset + idx]]
else: else:
o_offset += 1 o_offset += 1
i_offset += 1 i_offset += 1
...@@ -662,14 +657,15 @@ def compress_outs(op, not_required, inputs): ...@@ -662,14 +657,15 @@ def compress_outs(op, not_required, inputs):
node_inputs += nit_sot_ins node_inputs += nit_sot_ins
# other stuff # other stuff
op_inputs += op.inputs[i_offset:] op_inputs += op.inputs[i_offset:]
node_inputs += inputs[ni_offset+op.n_shared_outs+op.n_nit_sot:] node_inputs += inputs[ni_offset + op.n_shared_outs + op.n_nit_sot:]
if op.as_while: if op.as_while:
op_outputs += [op.outputs[o_offset]] op_outputs += [op.outputs[o_offset]]
map_old_new[o_offset] = len(op_outputs)-1 map_old_new[o_offset] = len(op_outputs) - 1
#map_old_new[len(op_outputs)-1] = o_offset #map_old_new[len(op_outputs)-1] = o_offset
return (op_inputs, op_outputs, info, node_inputs, map_old_new) return (op_inputs, op_outputs, info, node_inputs, map_old_new)
def find_up(l_node, f_node): def find_up(l_node, f_node):
r""" r"""
Goes up in the graph and returns True if a node in nodes is found. Goes up in the graph and returns True if a node in nodes is found.
...@@ -682,7 +678,8 @@ def find_up(l_node, f_node): ...@@ -682,7 +678,8 @@ def find_up(l_node, f_node):
nodes = gof.graph.io_toposort(l_ins, l_outs) nodes = gof.graph.io_toposort(l_ins, l_outs)
return f_node in nodes return f_node in nodes
def reconstruct_graph(inputs, outputs, tag = None):
def reconstruct_graph(inputs, outputs, tag=None):
""" """
Different interface to clone, that allows you to pass inputs. Different interface to clone, that allows you to pass inputs.
Compared to clone, this method always replaces the inputs with Compared to clone, this method always replaces the inputs with
...@@ -691,7 +688,7 @@ def reconstruct_graph(inputs, outputs, tag = None): ...@@ -691,7 +688,7 @@ def reconstruct_graph(inputs, outputs, tag = None):
""" """
if tag is None: if tag is None:
tag = '' tag = ''
nw_inputs = [safe_new(x,tag) for x in inputs] nw_inputs = [safe_new(x, tag) for x in inputs]
givens = {} givens = {}
for nw_x, x in zip(nw_inputs, inputs): for nw_x, x in zip(nw_inputs, inputs):
givens[x] = nw_x givens[x] = nw_x
...@@ -700,9 +697,10 @@ def reconstruct_graph(inputs, outputs, tag = None): ...@@ -700,9 +697,10 @@ def reconstruct_graph(inputs, outputs, tag = None):
if isinstance(inp, theano.Constant): if isinstance(inp, theano.Constant):
givens[inp] = inp.clone() givens[inp] = inp.clone()
nw_outputs = clone( outputs, replace=givens) nw_outputs = clone(outputs, replace=givens)
return (nw_inputs, nw_outputs) return (nw_inputs, nw_outputs)
class scan_args(object): class scan_args(object):
"""Parses the inputs and outputs of scan in an easy to manipulate format""" """Parses the inputs and outputs of scan in an easy to manipulate format"""
def __init__(self, outer_inputs, outer_outputs, def __init__(self, outer_inputs, outer_outputs,
...@@ -720,8 +718,8 @@ class scan_args(object): ...@@ -720,8 +718,8 @@ class scan_args(object):
q = 0 q = 0
n_seqs = info['n_seqs'] n_seqs = info['n_seqs']
self.outer_in_seqs = outer_inputs[p:p+n_seqs] self.outer_in_seqs = outer_inputs[p:p + n_seqs]
self.inner_in_seqs = inner_inputs[q:q+n_seqs] self.inner_in_seqs = inner_inputs[q:q + n_seqs]
p += n_seqs p += n_seqs
q += n_seqs q += n_seqs
...@@ -729,46 +727,47 @@ class scan_args(object): ...@@ -729,46 +727,47 @@ class scan_args(object):
n_mit_sot = info['n_mit_sot'] n_mit_sot = info['n_mit_sot']
self.mit_mot_in_slices = info['tap_array'][:n_mit_mot] self.mit_mot_in_slices = info['tap_array'][:n_mit_mot]
self.mit_sot_in_slices = info['tap_array'][n_mit_mot:n_mit_mot+n_mit_sot] self.mit_sot_in_slices = info['tap_array'][
n_mit_mot:n_mit_mot + n_mit_sot]
n_mit_mot_ins = sum(len(s) for s in self.mit_mot_in_slices) n_mit_mot_ins = sum(len(s) for s in self.mit_mot_in_slices)
n_mit_sot_ins = sum(len(s) for s in self.mit_sot_in_slices) n_mit_sot_ins = sum(len(s) for s in self.mit_sot_in_slices)
iimm = inner_inputs[q:q+n_mit_mot_ins] iimm = inner_inputs[q:q + n_mit_mot_ins]
self.inner_in_mit_mot = [] self.inner_in_mit_mot = []
qq = 0 qq = 0
for sl in self.mit_mot_in_slices: for sl in self.mit_mot_in_slices:
self.inner_in_mit_mot.append(iimm[qq:qq+len(sl)]) self.inner_in_mit_mot.append(iimm[qq:qq + len(sl)])
qq += len(sl) qq += len(sl)
q += n_mit_mot_ins q += n_mit_mot_ins
iims = inner_inputs[q:q+n_mit_sot_ins] iims = inner_inputs[q:q + n_mit_sot_ins]
self.inner_in_mit_sot = [] self.inner_in_mit_sot = []
qq = 0 qq = 0
for sl in self.mit_sot_in_slices: for sl in self.mit_sot_in_slices:
self.inner_in_mit_sot.append(iims[qq:qq+len(sl)]) self.inner_in_mit_sot.append(iims[qq:qq + len(sl)])
qq += len(sl) qq += len(sl)
q += n_mit_sot_ins q += n_mit_sot_ins
self.outer_in_mit_mot = outer_inputs[p:p+n_mit_mot] self.outer_in_mit_mot = outer_inputs[p:p + n_mit_mot]
p += n_mit_mot p += n_mit_mot
self.outer_in_mit_sot = outer_inputs[p:p+n_mit_sot] self.outer_in_mit_sot = outer_inputs[p:p + n_mit_sot]
p += n_mit_sot p += n_mit_sot
n_sit_sot = info['n_sit_sot'] n_sit_sot = info['n_sit_sot']
self.outer_in_sit_sot = outer_inputs[p:p+n_sit_sot] self.outer_in_sit_sot = outer_inputs[p:p + n_sit_sot]
self.inner_in_sit_sot = inner_inputs[q:q+n_sit_sot] self.inner_in_sit_sot = inner_inputs[q:q + n_sit_sot]
p += n_sit_sot p += n_sit_sot
q += n_sit_sot q += n_sit_sot
n_shared_outs = info['n_shared_outs'] n_shared_outs = info['n_shared_outs']
self.outer_in_shared = outer_inputs[p:p+n_shared_outs] self.outer_in_shared = outer_inputs[p:p + n_shared_outs]
self.inner_in_shared = inner_inputs[q:q+n_shared_outs] self.inner_in_shared = inner_inputs[q:q + n_shared_outs]
p += n_shared_outs p += n_shared_outs
q += n_shared_outs q += n_shared_outs
n_nit_sot = info['n_nit_sot'] n_nit_sot = info['n_nit_sot']
self.outer_in_nit_sot = outer_inputs[p:p+n_nit_sot] self.outer_in_nit_sot = outer_inputs[p:p + n_nit_sot]
p += n_nit_sot p += n_nit_sot
self.outer_in_non_seqs = outer_inputs[p:] self.outer_in_non_seqs = outer_inputs[p:]
...@@ -780,40 +779,39 @@ class scan_args(object): ...@@ -780,40 +779,39 @@ class scan_args(object):
self.mit_mot_out_slices = info['mit_mot_out_slices'] self.mit_mot_out_slices = info['mit_mot_out_slices']
n_mit_mot_outs = info['n_mit_mot_outs'] n_mit_mot_outs = info['n_mit_mot_outs']
self.outer_out_mit_mot = outer_outputs[p:p+n_mit_mot] self.outer_out_mit_mot = outer_outputs[p:p + n_mit_mot]
iomm = inner_outputs[q:q+n_mit_mot_outs] iomm = inner_outputs[q:q + n_mit_mot_outs]
self.inner_out_mit_mot = [] self.inner_out_mit_mot = []
qq = 0 qq = 0
for sl in self.mit_mot_out_slices: for sl in self.mit_mot_out_slices:
self.inner_out_mit_mot.append(iomm[qq:qq+len(sl)]) self.inner_out_mit_mot.append(iomm[qq:qq + len(sl)])
qq += len(sl) qq += len(sl)
p += n_mit_mot p += n_mit_mot
q += n_mit_mot_outs q += n_mit_mot_outs
self.outer_out_mit_sot = outer_outputs[p:p+n_mit_sot] self.outer_out_mit_sot = outer_outputs[p:p + n_mit_sot]
self.inner_out_mit_sot = inner_outputs[q:q+n_mit_sot] self.inner_out_mit_sot = inner_outputs[q:q + n_mit_sot]
p += n_mit_sot p += n_mit_sot
q += n_mit_sot q += n_mit_sot
self.outer_out_sit_sot = outer_outputs[p:p+n_sit_sot] self.outer_out_sit_sot = outer_outputs[p:p + n_sit_sot]
self.inner_out_sit_sot = inner_outputs[q:q+n_sit_sot] self.inner_out_sit_sot = inner_outputs[q:q + n_sit_sot]
p += n_sit_sot p += n_sit_sot
q += n_sit_sot q += n_sit_sot
self.outer_out_nit_sot = outer_outputs[p:p+n_nit_sot] self.outer_out_nit_sot = outer_outputs[p:p + n_nit_sot]
self.inner_out_nit_sot = inner_outputs[q:q+n_nit_sot] self.inner_out_nit_sot = inner_outputs[q:q + n_nit_sot]
p += n_nit_sot p += n_nit_sot
q += n_nit_sot q += n_nit_sot
self.outer_out_shared = outer_outputs[p:p+n_shared_outs] self.outer_out_shared = outer_outputs[p:p + n_shared_outs]
self.inner_out_shared = inner_outputs[q:q+n_shared_outs] self.inner_out_shared = inner_outputs[q:q + n_shared_outs]
p += n_shared_outs p += n_shared_outs
q += n_shared_outs q += n_shared_outs
self.other_info = dict() self.other_info = dict()
for k in ('truncate_gradient', 'name', 'mode', 'inplace', for k in ('truncate_gradient', 'name', 'mode', 'inplace',
'gpu','as_while', 'profile'): 'gpu', 'as_while', 'profile'):
self.other_info[k] = info[k] self.other_info[k] = info[k]
inner_inputs = property(lambda self: (self.inner_in_seqs + inner_inputs = property(lambda self: (self.inner_in_seqs +
...@@ -844,7 +842,8 @@ class scan_args(object): ...@@ -844,7 +842,8 @@ class scan_args(object):
self.outer_out_nit_sot + self.outer_out_nit_sot +
self.outer_out_shared)) self.outer_out_shared))
info = property(lambda self: dict(n_seqs=len(self.outer_in_seqs), info = property(lambda self: dict(
n_seqs=len(self.outer_in_seqs),
n_mit_mot=len(self.outer_in_mit_mot), n_mit_mot=len(self.outer_in_mit_mot),
n_mit_sot=len(self.outer_in_mit_sot), n_mit_sot=len(self.outer_in_mit_sot),
tap_array=(self.mit_mot_in_slices + tap_array=(self.mit_mot_in_slices +
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论