提交 2e5fcea2 authored 作者: lamblin's avatar lamblin

Merge pull request #1054 from pascanur/new_optimizations_scan

New optimizations scan
...@@ -1375,18 +1375,18 @@ class Scan(PureOp): ...@@ -1375,18 +1375,18 @@ class Scan(PureOp):
def compute_gradient(y, g_y): def compute_gradient(y, g_y):
if 'int' in str(g_y.dtype): if 'int' in str(g_y.dtype):
raise TypeError("Gradients may never be integers but g_y " raise TypeError("Gradients may never be integers but g_y "
"has type "+str(g_y.type)) "has type " + str(g_y.type))
wrt = [x for x in theano.gof.graph.inputs([y]) wrt = [x for x in theano.gof.graph.inputs([y])
if x in diff_inputs] if x in diff_inputs]
grads = gradient.grad( grads = gradient.grad(
cost = None, cost=None,
known_grads = {y : g_y }, known_grads={y: g_y},
wrt=wrt, consider_constant=wrt, wrt=wrt, consider_constant=wrt,
disconnected_inputs='ignore', disconnected_inputs='ignore',
return_disconnected='None') return_disconnected='None')
gmp = dict(zip(wrt, grads)) gmp = dict(zip(wrt, grads))
rval = [gmp.get(p, None) for p in diff_inputs] rval = [gmp.get(p, None) for p in diff_inputs]
return rval return rval
dC_dinps_t = [None for inp in diff_inputs] dC_dinps_t = [None for inp in diff_inputs]
disconnected_dC_dinps_t = [True for inp in diff_inputs] disconnected_dC_dinps_t = [True for inp in diff_inputs]
...@@ -1727,7 +1727,7 @@ class Scan(PureOp): ...@@ -1727,7 +1727,7 @@ class Scan(PureOp):
node = outs[0].owner node = outs[0].owner
for idx in xrange(self.n_shared_outs): for idx in xrange(self.n_shared_outs):
disconnected = True disconnected = True
connected_flags = self.connection_pattern(node)[idx+start] connected_flags = self.connection_pattern(node)[idx + start]
for dC_dout, connected in zip(dC_douts, connected_flags): for dC_dout, connected in zip(dC_douts, connected_flags):
if (not isinstance(dC_dout.type, DisconnectedType) and if (not isinstance(dC_dout.type, DisconnectedType) and
connected): connected):
......
...@@ -103,20 +103,35 @@ def remove_constants_and_unused_inputs_scan(node): ...@@ -103,20 +103,35 @@ def remove_constants_and_unused_inputs_scan(node):
except TypeError: except TypeError:
pass pass
elif op_ins[idx] in all_ins: elif op_ins[idx] in all_ins:
nw_inner += [op_ins[idx]] # Check for identical other sequence
nw_outer += [node.inputs[idx + 1]] identical_seqs = [x for x in nw_outer
if scan_utils.equal_computations(
[x], [node.inputs[idx + 1]])]
if identical_seqs:
index = node.inputs.index(identical_seqs[0]) - 1
givens[op_ins[idx]] = op_ins[index]
else:
nw_inner += [op_ins[idx]]
nw_outer += [node.inputs[idx + 1]]
nw_n_seqs = len(nw_inner) nw_n_seqs = len(nw_inner)
# Add outputs stuff # Add outputs stuff
nw_inner += out_stuff_inner nw_inner += out_stuff_inner
nw_outer += out_stuff_outer nw_outer += out_stuff_outer
# Look through non sequences # Look through non sequences
for nw_in, nw_out in zip(non_seqs, outer_non_seqs): for idx, (nw_in, nw_out) in enumerate(zip(non_seqs, outer_non_seqs)):
if isinstance(nw_out, tensor.Constant): if isinstance(nw_out, tensor.Constant):
givens[nw_in] = nw_out.clone() givens[nw_in] = nw_out.clone()
elif nw_in in all_ins: elif nw_in in all_ins:
nw_inner += [nw_in] identical_non_seqs = [x for x in outer_non_seqs[:idx]
nw_outer += [nw_out] if scan_utils.equal_computations(
[x], [nw_out])]
if identical_non_seqs:
index = outer_non_seqs.index(identical_non_seqs[0])
givens[nw_in] = non_seqs[index]
else:
nw_inner += [nw_in]
nw_outer += [nw_out]
if len(nw_inner) != len(op_ins): if len(nw_inner) != len(op_ins):
op_outs = scan_utils.clone(op_outs, replace=givens) op_outs = scan_utils.clone(op_outs, replace=givens)
...@@ -129,17 +144,6 @@ def remove_constants_and_unused_inputs_scan(node): ...@@ -129,17 +144,6 @@ def remove_constants_and_unused_inputs_scan(node):
else: else:
return False return False
scan_seqopt = theano.gof.SequenceDB()
# We run before blas opt at 1.7 and specialize 2.0
# but after stabilize at 1.5. Should we put it before stabilize?
optdb.register('scan_seqopt', scan_seqopt, 1.6, 'fast_run', 'scan')
scan_seqopt.register('scanOp_remove_constants_and_unused_inputs',
opt.in2out(remove_constants_and_unused_inputs_scan,
ignore_newtrees=True),
5,
'fast_run',
'scan')
# This is a global opt for historical reason # This is a global opt for historical reason
# It should be possible to change it to a local opt. # It should be possible to change it to a local opt.
...@@ -172,26 +176,19 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -172,26 +176,19 @@ class PushOutNonSeqScan(gof.Optimizer):
replace_with_out = [] replace_with_out = []
op = node.op op = node.op
# Construct the list of non_sequences to simplify a few things # Construct the list of non_sequences to simplify a few things
st = op.n_seqs inner_non_seqs = op.inner_non_seqs(clean_inputs)
st += int(numpy.sum([len(x) for x in outer_non_seqs = op.outer_non_seqs(node.inputs)
op.tap_array[:(op.n_mit_mot + op.n_mit_sot)]])) inner_seqs = op.inner_seqs(clean_inputs)
st += op.n_sit_sot outer_seqs = op.outer_seqs(node.inputs)
st += op.n_shared_outs assert len(inner_non_seqs) == len(outer_non_seqs)
non_seqs = clean_inputs[st:] assert len(inner_seqs) == len(outer_seqs)
st = (op.n_seqs +
op.n_mit_mot +
op.n_mit_sot +
op.n_sit_sot +
op.n_nit_sot +
op.n_shared_outs + 1)
outer_non_seqs = node.inputs[st:]
assert len(non_seqs) == len(outer_non_seqs)
while changed and counts < max_iterations: while changed and counts < max_iterations:
counts += 1 counts += 1
changed = False changed = False
for nd in local_fgraph.toposort(): for nd in local_fgraph.toposort():
if (numpy.all([(x in non_seqs) or if (numpy.all([(x in inner_non_seqs) or
(x.owner in to_remove) or (x.owner in to_remove) or
isinstance(x, tensor.Constant) isinstance(x, tensor.Constant)
for x in nd.inputs]) and for x in nd.inputs]) and
...@@ -208,8 +205,9 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -208,8 +205,9 @@ class PushOutNonSeqScan(gof.Optimizer):
to_remove.append(nd) to_remove.append(nd)
outside_ins = [] outside_ins = []
for x in nd.inputs: for x in nd.inputs:
if x in non_seqs: if x in inner_non_seqs:
outside_ins += [outer_non_seqs[non_seqs.index(x)]] _idx = inner_non_seqs.index(x)
outside_ins += [outer_non_seqs[_idx]]
elif x in to_replace: elif x in to_replace:
outside_ins += [ outside_ins += [
replace_with_out[to_replace.index(x)]] replace_with_out[to_replace.index(x)]]
...@@ -297,6 +295,8 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -297,6 +295,8 @@ class PushOutNonSeqScan(gof.Optimizer):
*shape) *shape)
# We need to add one extra dimension to the outputs # We need to add one extra dimension to the outputs
# because the scan op expects for a tensor3, to which an
# subtensor is applied that takes only the last element
if replace_with: if replace_with:
fgraph.replace_all_validate_remove( fgraph.replace_all_validate_remove(
replace_with.items(), replace_with.items(),
...@@ -307,11 +307,200 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -307,11 +307,200 @@ class PushOutNonSeqScan(gof.Optimizer):
return False return False
scan_seqopt.register('scanOp_pushout_nonseqs_ops', # This is a global opt for historical reason
PushOutNonSeqScan(), # It should be possible to change it to a local opt.
1, class PushOutSeqScan(gof.Optimizer):
'fast_run',
'scan') def __init__(self):
gof.Optimizer.__init__(self)
def add_requirements(self, fgraph):
fgraph.attach_feature(gof.toolbox.ReplaceValidate())
def apply(self, fgraph):
nodelist = [x for x in fgraph.toposort() if isinstance(x.op,
scan_op.Scan)]
for node in nodelist:
self.process_node(fgraph, node)
def process_node(self, fgraph, node):
# this flag tells if there was any change during the last iterations
changed = True
clean_inputs, clean_outputs = scan_utils.reconstruct_graph(
node.op.inputs, node.op.outputs)
local_fgraph = gof.FunctionGraph(clean_inputs, clean_outputs)
max_iterations = 2 * len(local_fgraph.toposort()) + 3
counts = 0
to_remove = []
to_replace = []
replace_with_in = []
replace_with_out = []
op = node.op
# Construct the list of non_sequences to simplify a few things
inner_non_seqs = op.inner_non_seqs(clean_inputs)
outer_non_seqs = op.outer_non_seqs(node.inputs)
inner_seqs = op.inner_seqs(clean_inputs)
outer_seqs = op.outer_seqs(node.inputs)
assert len(inner_non_seqs) == len(outer_non_seqs)
assert len(inner_seqs) == len(outer_seqs)
while changed and counts < max_iterations:
counts += 1
changed = False
for nd in local_fgraph.toposort():
if (isinstance(nd.op, theano.tensor.Elemwise) and
numpy.all([(x in inner_non_seqs) or
(x.owner in to_remove) or
isinstance(x, tensor.Constant) or
(x in inner_seqs)
for x in nd.inputs]) and
not nd in to_remove):
to_remove.append(nd)
outside_ins = []
for x in nd.inputs:
if x in inner_non_seqs:
_idx = inner_non_seqs.index(x)
outside_ins += [outer_non_seqs[_idx]]
elif x in inner_seqs:
outside_ins += [outer_seqs[inner_seqs.index(x)]]
elif x in to_replace:
outside_ins += [replace_with_out[\
to_replace.index(x)]]
elif isinstance(x, theano.Constant):
outside_ins += [x.clone()]
else:
raise Exception(
('Error in the `scan_pushout_non_seq_'
'operations`. The optimization tries '
'to move some computation fron scan '
'which is not allowed to move. Report '
'this on theano-users list'), x)
nw_outer_node = nd.op.make_node(*outside_ins)
# Step 2. Create variables for replacements
for idx, y in enumerate(nd.outputs):
y_place_holder = scan_utils.safe_new(y, '_replace')
to_replace += [y]
replace_with_in += [y_place_holder]
replace_with_out += [nw_outer_node.outputs[idx]]
changed = True
elif (isinstance(nd.op, theano.tensor.DimShuffle) and
(nd.inputs[0] in inner_seqs or
nd.inputs[0].owner in to_remove) and
not nd in to_remove):
to_remove.append(nd)
x = nd.inputs[0]
if x in inner_seqs:
outside_ins = outer_seqs[inner_seqs.index(x)]
elif x in to_replace:
outside_ins = replace_with_out[to_replace.index(x)]
new_ord = (0,)
for old_ord in nd.op.new_order:
if isinstance(old_ord, int):
new_ord += (old_ord + 1,)
else:
new_ord += (old_ord,)
new_outer = outside_ins.dimshuffle(new_ord)
y = nd.outputs[0]
y_place_holder = scan_utils.safe_new(y, '_replace')
to_replace += [y]
replace_with_in += [y_place_holder]
replace_with_out += [new_outer]
changed = True
if counts >= max_iterations:
raise Exception('Error in the `scan_pushout_non_seq_operations`.'
' The optimization exhausted the maximal number '
'of iterations allowed!')
# We need to check all candidate replacements and choose those that
# make sense for us
# Step 1. which elements of `to_replace` are used by remaining
# components of the inner function
clean_to_replace = []
clean_replace_with_in = []
clean_replace_with_out = []
existent_nodes = [nd for nd in local_fgraph.toposort()
if nd not in to_remove]
to_keep = []
for nd in existent_nodes:
to_keep += nd.inputs
for idx, out in enumerate(to_replace):
if out in to_keep and out.owner not in existent_nodes:
clean_to_replace += [out]
clean_replace_with_in += [replace_with_in[idx]]
clean_replace_with_out += [replace_with_out[idx]]
if len(clean_to_replace) > 0:
# We can finally put an end to all this madness
givens = {}
nw_outer = []
nw_inner = []
for to_repl, repl_in, repl_out in zip(clean_to_replace,
clean_replace_with_in,
clean_replace_with_out):
if isinstance(repl_out, theano.Constant):
repl_in = repl_out.clone()
else:
nw_inner += [repl_in]
nw_outer += [repl_out]
givens[to_repl] = repl_in
_op_outs = scan_utils.clone(clean_outputs,
replace=givens)
_op_ins = nw_inner + clean_inputs
op_ins, op_outs = scan_utils.reconstruct_graph(_op_ins, _op_outs)
# Reconstruct node
nw_info = op.info.copy()
nw_info['n_seqs'] += len(nw_inner)
nwScan = scan_op.Scan(op_ins, op_outs, nw_info)
nw_node = nwScan.make_node(* (node.inputs[:1] + nw_outer +
node.inputs[1:]))
fgraph.replace_all_validate_remove(
zip(node.outputs, nw_node.outputs),
remove=[node],
reason='scan_push_computation_out')
return True
elif (to_keep == [] and
not op.as_while and
not op.outer_mitmot(node)):
# Nothing in the inner graph should be kept
replace_with = gof.python25.OrderedDict()
for idx, out in enumerate(to_replace):
if out in local_fgraph.outputs:
x = node.outputs[local_fgraph.outputs.index(out)]
_y = replace_with_out[idx]
ls = local_fgraph.outputs
if out in op.inner_mitsot_outs(ls):
odx = op.inner_mitsot_outs(ls).index(out)
inp = op.outer_mitsot(node)[odx]
st = abs(numpy.min(op.mitsot_taps()))
y = tensor.set_subtensor(inp[st:], _y)
elif out in op.inner_sitsot_outs(ls):
odx = op.inner_sitsot_outs(ls).index(out)
inp = op.outer_sitsot(node)[odx]
y = tensor.set_subtensor(inp[1:], _y)
elif out in op.inner_nitsot_outs(ls):
y = _y
else:
y = _y[-1]
replace_with[x] = y
# We need to add one extra dimension to the outputs
if replace_with:
fgraph.replace_all_validate_remove(
replace_with.items(),
remove=[node],
reason='scan_push_seq_computation_out')
else:
return False
class ScanInplaceOptimizer(Optimizer): class ScanInplaceOptimizer(Optimizer):
...@@ -373,14 +562,6 @@ class ScanInplaceOptimizer(Optimizer): ...@@ -373,14 +562,6 @@ class ScanInplaceOptimizer(Optimizer):
# Failed moving output to be comptued inplace # Failed moving output to be comptued inplace
pass pass
optdb.register('scanOp_make_inplace',
ScanInplaceOptimizer(typeConstructor=None,
gpu_flag=False),
75,
'fast_run',
'inplace',
'scan')
class ScanSaveMem(gof.Optimizer): class ScanSaveMem(gof.Optimizer):
""" Graph Optimizer that reduces scan memory consumption """ """ Graph Optimizer that reduces scan memory consumption """
...@@ -858,15 +1039,6 @@ class ScanSaveMem(gof.Optimizer): ...@@ -858,15 +1039,6 @@ class ScanSaveMem(gof.Optimizer):
if not hasattr(node.op, '_scan_savemem_visited'): if not hasattr(node.op, '_scan_savemem_visited'):
self.process_node(fgraph, node) self.process_node(fgraph, node)
# Just before specialize to have the other optimization
# like constant folding being applied
# This don't introduce inplace.
scan_seqopt.register('scanOp_save_mem',
ScanSaveMem(),
4,
'fast_run',
'scan')
class ScanMerge(gof.Optimizer): class ScanMerge(gof.Optimizer):
""" Graph Optimizer that merges different scan ops """ """ Graph Optimizer that merges different scan ops """
...@@ -1048,16 +1220,6 @@ class ScanMerge(gof.Optimizer): ...@@ -1048,16 +1220,6 @@ class ScanMerge(gof.Optimizer):
reason='scan_merge') reason='scan_merge')
# after const merge but before stabilize so that we can have identity
# for equivalent nodes but we still have the chance to hoist stuff out
# of the scan later.
scan_seqopt.register('scanOp_merge',
ScanMerge(),
2,
'fast_run',
'scan')
def has_duplicates(l): def has_duplicates(l):
"""returns true if l has any duplicates (according to __eq__).""" """returns true if l has any duplicates (according to __eq__)."""
return len(set(l)) < len(l) return len(set(l)) < len(l)
...@@ -1179,10 +1341,48 @@ def scan_merge_inouts(node): ...@@ -1179,10 +1341,48 @@ def scan_merge_inouts(node):
seen.append((i, o)) seen.append((i, o))
return o return o
def map_nitsot_out(i, o, sh, seen):
for p, (si, so, ssh) in enumerate(seen):
if equal_computations([i], [si], left, right):
if equal_computations([sh], [ssh]):
return so
try:
vsh = int(opt.get_constant_value(sh))
vssh = int(opt.get_constant_value(ssh))
except TypeError:
return o
if vsh == vssh:
return so
elif vsh > vssh:
seen[p] = (i, o, sh)
return o
else:
return so[:vsh]
seen.append((i, o, sh))
return o
seen = [] seen = []
na.outer_out_nit_sot = [map_out(i, o, seen)
for i, o in zip(na.inner_out_nit_sot, shapes = []
na.outer_out_nit_sot)] for x in na.outer_in_nit_sot:
if x.ndim > 0:
if hasattr(node.fgraph, 'shape_feature'):
shapes.append(
node.fgraph.shape_feature.shape_of[x][0])
else:
shapes.append(x.shape[0])
else:
# If x is a scalar, then it means its value is the number of
# items scan is supposed to store for this nit_sot sequence
shapes.append(x)
tmp = [map_nitsot_out(i, o, sh, seen)
for i, o, sh in zip(na.inner_out_nit_sot,
na.outer_out_nit_sot,
shapes)]
na.outer_out_nit_sot = [map_nitsot_out(i, o, sh, seen)
for i, o, sh in zip(na.inner_out_nit_sot,
na.outer_out_nit_sot,
shapes)]
seen = [] seen = []
na.outer_out_sit_sot = [map_out(i, o, seen) na.outer_out_sit_sot = [map_out(i, o, seen)
...@@ -1209,8 +1409,288 @@ def scan_merge_inouts(node): ...@@ -1209,8 +1409,288 @@ def scan_merge_inouts(node):
return na.outer_outputs return na.outer_outputs
scan_seqopt.register('scanOp_merge_inouts',
opt.in2out(scan_merge_inouts, ignore_newtrees=True), class PushOutDot1(gof.Optimizer):
3, """Graph optimizer for Scan(makes it run inplace)"""
'fast_run', def __init__(self):
'scan') Optimizer.__init__(self)
def add_requirements(self, fgraph):
fgraph.attach_feature(toolbox.ReplaceValidate())
fgraph.attach_feature(DestroyHandler())
def apply(self, fgraph):
nodes = fgraph.toposort()
scan_nodes = [x for x in nodes if (isinstance(x.op, scan_op.Scan))]
for node in scan_nodes:
self.apply_opt(fgraph, node)
def apply_opt(self, fgraph, node):
# Replace pattern of the form
# x[t] = x[t-1] + dot(seq[t], value)
# with Sequence.reshape((-1, seq.shape[2])) \dot Value
# When seq[t] is a vector/matrix and `value` is a matrix
# Note that this works when only you need X[-1] in the end
# and assumes dimshuffle are applied to vectors before calling dot
op = node.op
sitsot_ins = op.inner_sitsot(op.inputs)
sitsot_outs = op.inner_sitsot_outs(op.outputs)
outer_sitsot = op.outer_sitsot_outs(node)
seqs = op.inner_seqs(op.inputs)
for inp, out, outer_out in zip(sitsot_ins, sitsot_outs, outer_sitsot):
if (out.owner and
isinstance(out.owner.op, theano.tensor.Elemwise) and
isinstance(out.owner.op.scalar_op, theano.scalar.Add) and
inp in out.owner.inputs and
len(outer_out.clients) == 1 and
not isinstance(outer_out.clients[0][0], str) and
isinstance(outer_out.clients[0][0].op, theano.tensor.Subtensor)
and outer_out.clients[0][0].op.idx_list == (-1,)):
x = out.owner.inputs[0]
if x == inp:
x = out.owner.inputs[1]
# We need to check if x is the result of an outer product
if (x.owner and
isinstance(x.owner.op, theano.tensor.Dot) and
x.owner.inputs[0].ndim == 2 and
x.owner.inputs[1].ndim == 2):
# We need to check if any of the inputs are a sequence
inp1 = x.owner.inputs[0]
inp2 = x.owner.inputs[1]
if inp1 in seqs or inp2 in seqs:
new_scan_out = inp2
if inp2 in seqs:
new_scan_out = inp1
idx = sitsot_outs.index(out)
# We've found our pattern and need to construct a new
# scan node to replace this one. For this we need to
# replace the sit_sot output with a nit_sot output
# First let us split all arguments according to their
# corresponding categories
inner_seqs = op.inner_seqs(op.inputs)
outer_seqs = op.outer_seqs(node)
inner_mitmot = op.inner_mitmot(op.inputs)
outer_mitmot = op.outer_mitmot(node)
inner_mitmot_outs = op.inner_mitmot_outs(op.outputs)
inner_mitsot = op.inner_mitsot(op.inputs)
outer_mitsot = op.outer_mitsot(node)
inner_mitsot_outs = op.inner_mitsot_outs(op.outputs)
inner_sitsot = op.inner_sitsot(op.inputs)
outer_sitsot = op.outer_sitsot(node)
inner_sitsot_outs = op.inner_sitsot_outs(op.outputs)
outer_nitsot = op.outer_nitsot(node)
inner_nitsot_outs = op.inner_nitsot_outs(op.outputs)
inner_shared = op.inner_shared(op.inputs)
outer_shared = op.outer_shared(node)
inner_shared_outs = op.inner_shared_outs(op.outputs)
inner_non_seqs = op.inner_non_seqs(op.inputs)
outer_non_seqs = op.outer_non_seqs(node)
new_info = op.info.copy()
st = len(op.mitmot_taps()) + len(op.mitsot_taps())
new_info['tap_array'] = (new_info['tap_array'][:st + idx] +
new_info['tap_array'][st + idx + 1:])
new_info['n_sit_sot'] -= 1
new_info['n_nit_sot'] += 1
inner_sitsot = inner_sitsot[:idx] + inner_sitsot[idx + 1:]
outer_sitsot = outer_sitsot[:idx] + outer_sitsot[idx + 1:]
inner_sitsot_outs = inner_sitsot_outs[:idx] +\
inner_sitsot_outs[idx + 1:]
# add n_steps as the length
inner_nitsot_outs.append(new_scan_out)
_new_inner_inps = (inner_seqs +
inner_mitmot +
inner_mitsot +
inner_sitsot +
inner_shared +
inner_non_seqs)
_new_inner_outs = (inner_mitmot_outs +
inner_mitsot_outs +
inner_sitsot_outs +
inner_nitsot_outs +
inner_shared_outs)
new_inner_inps, new_inner_outs =\
scan_utils.reconstruct_graph(
_new_inner_inps, _new_inner_outs)
new_op = scan_op.Scan(new_inner_inps, new_inner_outs,
new_info)
_scan_inputs = ([node.inputs[0]] +
outer_seqs +
outer_mitmot +
outer_mitsot +
outer_sitsot +
outer_shared +
outer_nitsot +
[node.inputs[0]] +
outer_non_seqs)
new_outs = new_op(*_scan_inputs)
# We need now to pair correctly the new outputs with the
# old ones
outer_mitmot_outs = new_op.outer_mitmot_outs(new_outs)
outer_mitsot_outs = new_op.outer_mitsot_outs(new_outs)
outer_sitsot_outs = new_op.outer_sitsot_outs(new_outs)
outer_nitsot_outs = new_op.outer_nitsot_outs(new_outs)
outer_shared_outs = new_op.outer_shared_outs(new_outs)
_val = outer_nitsot_outs[-1]
outer_nitsot_outs = outer_nitsot_outs[:-1]
if inp1 in seqs:
_out_seq = op.outer_seqs(node)[seqs.index(inp1)]
# We need to clip the seq to the number of steps
_out_seq = _out_seq[:node.inputs[0]]
sh0 = _out_seq.shape[0]
sh1 = _out_seq.shape[1]
sh2 = _out_seq.shape[2]
out_seq = _out_seq.dimshuffle(1, 0, 2)
out_seq = out_seq.reshape((sh1, sh0 * sh2))
sh0 = _val.shape[0]
sh1 = _val.shape[1]
sh2 = _val.shape[2]
val = _val.reshape((sh0 * sh1, sh2))
new_out = tensor.dot(out_seq, val)
else:
_out_seq = op.outer_seqs(node)[seqs.index(inp2)]
out_seq = _out_seq.reshape(
(_out_seq.shape[0] * _out_seq.shape[1],
_out_seq.shape[2]))
val = _val.dimshuffle(1, 0, 2).reshape(
(_val.shape[1],
_val.shape[0] * _val.shape[2]))
new_out = tensor.dot(val, out_seq)
pos = node.outputs.index(outer_out)
old_new = zip(node.outputs[:pos], new_outs[:pos])
old = node.outputs[pos].clients[0][0].outputs[0]
old_new.append((old, new_out))
old_new += zip(node.outputs[pos+1:], new_outs[pos:])
fgraph.replace_all_validate_remove(old_new,
remove = [node],
reason='PushOutDot1')
# I've added an equilibrium because later scan optimization in the sequence
# can make it such that earlier optimizations should apply. However, in
# general I do not expect the sequence to run more then once
scan_eqopt1 = theano.gof.EquilibriumDB()
scan_seqopt1 = theano.gof.SequenceDB()
scan_eqopt2 = theano.gof.EquilibriumDB()
scan_seqopt2 = theano.gof.EquilibriumDB()
# We run before blas opt at 1.7 and specialize 2.0
# but after stabilize at 1.5. Should we put it before stabilize?
optdb.register('scan_eqopt1', scan_eqopt1, .1, 'fast_run', 'scan')
optdb.register('scan_eqopt2', scan_eqopt2, 1.6, 'fast_run', 'scan')
optdb.register('scanOp_make_inplace',
ScanInplaceOptimizer(typeConstructor=None,
gpu_flag=False),
75,
'fast_run',
'inplace',
'scan')
scan_eqopt2.register(
'all_scan_opts', scan_seqopt2, 1, 'fast_run', 'scan')
scan_eqopt1.register(
'all_pushout_opt', scan_seqopt1, 1, 'fast_run', 'scan')
scan_seqopt1.register('scanOp_remove_constants_and_unused_inputs0',
opt.in2out(remove_constants_and_unused_inputs_scan,
ignore_newtrees=True),
1,
'fast_run',
'scan')
scan_seqopt1.register('scanOp_pushout_nonseqs_ops',
PushOutNonSeqScan(),
2,
'fast_run',
'scan')
scan_seqopt1.register('scanOp_pushout_seqs_ops',
PushOutSeqScan(),
3,
'fast_run',
'scan')
scan_seqopt1.register('scan_pushout_dot1',
PushOutDot1(),
4,
'fast_run',
'more_mem',
'scan')
scan_seqopt2.register('constant_folding_for_scan2',
opt.in2out(tensor.opt.constant_folding,
ignore_newtrees=True),
1,
'fast_run',
'scan')
scan_seqopt2.register('scanOp_remove_constants_and_unused_inputs0',
opt.in2out(remove_constants_and_unused_inputs_scan,
ignore_newtrees=True),
2,
'fast_run',
'scan')
# after const merge but before stabilize so that we can have identity
# for equivalent nodes but we still have the chance to hoist stuff out
# of the scan later.
scan_seqopt2.register('scanOp_merge',
ScanMerge(),
4,
'fast_run',
'scan')
# After Merge optimization
scan_seqopt2.register('scanop_remove_constants_and_unused_inputs2',
opt.in2out(remove_constants_and_unused_inputs_scan,
ignore_newtrees=True),
5,
'fast_run',
'scan')
scan_seqopt2.register('scanOp_merge_inouts',
opt.in2out(scan_merge_inouts, ignore_newtrees=True),
6,
'fast_run',
'scan')
# Just before specialize to have the other optimization
# like constant folding being applied
# This don't introduce inplace.
scan_seqopt2.register('scanOp_save_mem',
ScanSaveMem(),
7,
'fast_run',
'scan')
# After everything else
scan_seqopt2.register('scanOp_remove_constants_and_unused_inputs3',
opt.in2out(remove_constants_and_unused_inputs_scan,
ignore_newtrees=True),
8,
'fast_run',
'scan')
...@@ -191,8 +191,6 @@ def get_updates_and_outputs(ls): ...@@ -191,8 +191,6 @@ def get_updates_and_outputs(ls):
this function know how to put it in that order? this function know how to put it in that order?
""" """
def is_outputs(elem): def is_outputs(elem):
if (isinstance(elem, (list, tuple)) and if (isinstance(elem, (list, tuple)) and
all([isinstance(x, theano.Variable) for x in elem])): all([isinstance(x, theano.Variable) for x in elem])):
...@@ -206,7 +204,7 @@ def get_updates_and_outputs(ls): ...@@ -206,7 +204,7 @@ def get_updates_and_outputs(ls):
# Make sure the updates will be applied in a deterministic order # Make sure the updates will be applied in a deterministic order
if not isinstance(elem, gof.python25.OrderedDict): if not isinstance(elem, gof.python25.OrderedDict):
warnings.warn("Expected OrderedDict or OrderedUpdates, got "\ warnings.warn("Expected OrderedDict or OrderedUpdates, got "\
+str(type(elem))+". This can make your script non-" + str(type(elem)) + ". This can make your script non-"
"deterministic.") "deterministic.")
return True return True
# Dictionaries can be given as lists of tuples # Dictionaries can be given as lists of tuples
...@@ -253,7 +251,6 @@ def get_updates_and_outputs(ls): ...@@ -253,7 +251,6 @@ def get_updates_and_outputs(ls):
'values, you can use `tensor.constant` to turn them into ' 'values, you can use `tensor.constant` to turn them into '
'Theano variables.') 'Theano variables.')
if is_outputs(ls): if is_outputs(ls):
return None, _list(ls), OrderedDict() return None, _list(ls), OrderedDict()
if is_updates(ls): if is_updates(ls):
...@@ -389,7 +386,7 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None): ...@@ -389,7 +386,7 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
elif (isinstance(dx, tensor.Constant) and elif (isinstance(dx, tensor.Constant) and
isinstance(dy, tensor.Constant)): isinstance(dy, tensor.Constant)):
if not (numpy.all(dx.data == dy.data) and if not (numpy.all(dx.data == dy.data) and
dx.dtype == dy.dtype and dx.type.dtype == dy.type.dtype and
dx.data.shape == dy.data.shape): dx.data.shape == dy.data.shape):
return False return False
else: else:
...@@ -413,7 +410,7 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None): ...@@ -413,7 +410,7 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
if (isinstance(dx, tensor.Constant) and if (isinstance(dx, tensor.Constant) and
isinstance(dy, tensor.Constant)): isinstance(dy, tensor.Constant)):
if not (numpy.all(dx.data == dy.data) and if not (numpy.all(dx.data == dy.data) and
dx.dtype == dy.dtype and dx.type.dtype == dy.type.dtype and
dx.data.shape == dy.data.shape): dx.data.shape == dy.data.shape):
return False return False
else: else:
......
...@@ -2346,9 +2346,7 @@ class T_Scan(unittest.TestCase): ...@@ -2346,9 +2346,7 @@ class T_Scan(unittest.TestCase):
# this new assert is here to test if scan_merging works .. # this new assert is here to test if scan_merging works ..
nb_scan = len([n for n in topo nb_scan = len([n for n in topo
if isinstance(n.op, theano.scan_module.scan_op.Scan)]) if isinstance(n.op, theano.scan_module.scan_op.Scan)])
# For this to work we need an optimization that it will be pushed in self.assertTrue(nb_scan == 1)
# a new pull request
self.assertTrue(nb_scan == 2)
nb_shape_i = len([n for n in topo nb_shape_i = len([n for n in topo
if isinstance(n.op, theano.tensor.opt.Shape_i)]) if isinstance(n.op, theano.tensor.opt.Shape_i)])
if theano.config.mode != 'FAST_COMPILE': if theano.config.mode != 'FAST_COMPILE':
...@@ -2364,7 +2362,8 @@ class T_Scan(unittest.TestCase): ...@@ -2364,7 +2362,8 @@ class T_Scan(unittest.TestCase):
sx, upx = theano.scan(sum, sequences=[x]) sx, upx = theano.scan(sum, sequences=[x])
sy, upy = theano.scan(sum, sequences=[y]) sy, upy = theano.scan(sum, sequences=[y])
f = theano.function([x, y], [sx, sy], mode=mode_with_opt) f = theano.function([x, y], [sx, sy],
mode=mode_with_opt.excluding('scanOp_pushout_seqs_ops'))
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
scans = filter(lambda n: isinstance( scans = filter(lambda n: isinstance(
n.op, theano.scan_module.scan_op.Scan), topo) n.op, theano.scan_module.scan_op.Scan), topo)
...@@ -2373,7 +2372,8 @@ class T_Scan(unittest.TestCase): ...@@ -2373,7 +2372,8 @@ class T_Scan(unittest.TestCase):
sx, upx = theano.scan(sum, sequences=[x], n_steps=2) sx, upx = theano.scan(sum, sequences=[x], n_steps=2)
sy, upy = theano.scan(sum, sequences=[y], n_steps=3) sy, upy = theano.scan(sum, sequences=[y], n_steps=3)
f = theano.function([x, y], [sx, sy], mode=mode_with_opt) f = theano.function([x, y], [sx, sy],
mode=mode_with_opt.excluding('scanOp_pushout_seqs_ops'))
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
scans = filter(lambda n: isinstance( scans = filter(lambda n: isinstance(
n.op, theano.scan_module.scan_op.Scan), topo) n.op, theano.scan_module.scan_op.Scan), topo)
...@@ -2382,7 +2382,8 @@ class T_Scan(unittest.TestCase): ...@@ -2382,7 +2382,8 @@ class T_Scan(unittest.TestCase):
sx, upx = theano.scan(sum, sequences=[x], n_steps=4) sx, upx = theano.scan(sum, sequences=[x], n_steps=4)
sy, upy = theano.scan(sum, sequences=[y], n_steps=4) sy, upy = theano.scan(sum, sequences=[y], n_steps=4)
f = theano.function([x, y], [sx, sy], mode=mode_with_opt) f = theano.function([x, y], [sx, sy],
mode=mode_with_opt.excluding('scanOp_pushout_seqs_ops'))
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
scans = filter(lambda n: isinstance( scans = filter(lambda n: isinstance(
n.op, theano.scan_module.scan_op.Scan), topo) n.op, theano.scan_module.scan_op.Scan), topo)
...@@ -2391,7 +2392,8 @@ class T_Scan(unittest.TestCase): ...@@ -2391,7 +2392,8 @@ class T_Scan(unittest.TestCase):
sx, upx = theano.scan(sum, sequences=[x]) sx, upx = theano.scan(sum, sequences=[x])
sy, upy = theano.scan(sum, sequences=[x]) sy, upy = theano.scan(sum, sequences=[x])
f = theano.function([x], [sx, sy], mode=mode_with_opt) f = theano.function([x], [sx, sy],
mode=mode_with_opt.excluding('scanOp_pushout_seqs_ops'))
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
scans = filter(lambda n: scans = filter(lambda n:
isinstance(n.op, theano.scan_module.scan_op.Scan), topo) isinstance(n.op, theano.scan_module.scan_op.Scan), topo)
...@@ -2401,7 +2403,7 @@ class T_Scan(unittest.TestCase): ...@@ -2401,7 +2403,7 @@ class T_Scan(unittest.TestCase):
sy, upy = theano.scan(sum, sequences=[x], mode='FAST_COMPILE') sy, upy = theano.scan(sum, sequences=[x], mode='FAST_COMPILE')
f = theano.function([x], [sx, sy], f = theano.function([x], [sx, sy],
mode=mode_with_opt) mode=mode_with_opt.excluding('scanOp_pushout_seqs_ops'))
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
scans = filter(lambda n: scans = filter(lambda n:
isinstance(n.op, theano.scan_module.scan_op.Scan), topo) isinstance(n.op, theano.scan_module.scan_op.Scan), topo)
...@@ -2410,7 +2412,8 @@ class T_Scan(unittest.TestCase): ...@@ -2410,7 +2412,8 @@ class T_Scan(unittest.TestCase):
sx, upx = theano.scan(sum, sequences=[x]) sx, upx = theano.scan(sum, sequences=[x])
sy, upy = theano.scan(sum, sequences=[x], truncate_gradient=1) sy, upy = theano.scan(sum, sequences=[x], truncate_gradient=1)
f = theano.function([x], [sx, sy], mode=mode_with_opt) f = theano.function([x], [sx, sy],
mode=mode_with_opt.excluding('scanOp_pushout_seqs_ops'))
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
scans = filter(lambda n: scans = filter(lambda n:
isinstance(n.op, theano.scan_module.scan_op.Scan), topo) isinstance(n.op, theano.scan_module.scan_op.Scan), topo)
...@@ -2820,12 +2823,12 @@ class T_Scan(unittest.TestCase): ...@@ -2820,12 +2823,12 @@ class T_Scan(unittest.TestCase):
vx = numpy.zeros((50,), dtype=theano.config.floatX) vx = numpy.zeros((50,), dtype=theano.config.floatX)
vx[23] = 4 vx[23] = 4
out, out2 = f(vx) out, out2 = f(vx)
print 'len_out', len(out)
assert len(out) == 24 assert len(out) == 24
assert numpy.all(out2 == vx + 2) assert numpy.all(out2 == vx + 2)
lssc = [x for x in f.maker.fgraph.toposort() lssc = [x for x in f.maker.fgraph.toposort()
if isinstance(x.op, theano.scan_module.scan_op.Scan)] if isinstance(x.op, theano.scan_module.scan_op.Scan)]
assert len(lssc) == 2 # One scan node gets optimnized out
assert len(lssc) == 1
@dec.knownfailureif(True, @dec.knownfailureif(True,
("This test fails because not typed outputs_info " ("This test fails because not typed outputs_info "
...@@ -3303,6 +3306,70 @@ class T_Scan(unittest.TestCase): ...@@ -3303,6 +3306,70 @@ class T_Scan(unittest.TestCase):
theano.scan_module.scan_op.Scan)] theano.scan_module.scan_op.Scan)]
assert len(scan_nodes) == 1 assert len(scan_nodes) == 1
def test_eliminate_seqs(self):
U = tensor.vector('U')
sh = theano.shared(asarrayX(2.))
x1 = tensor.vector('x1')
x2 = tensor.scalar('x2')
def rec_fn(*args):
u_t = args[0]
return [(u_t + 1, # mitsot
u_t + 2, # sitsot
u_t + 3), # nitsot
{sh: u_t + 4}] # shared
[X1, X2, X3], updates = theano.scan(
rec_fn,
U,
[dict(initial=x1, taps=[-1, -3]), x2, None],
n_steps=None,
truncate_gradient=-1,
go_backwards=False)
f = theano.function([U, x1, x2], [X1, X2, X3],
updates=updates,
mode=theano.Mode(linker='py'),
allow_input_downcast=True)
rng = numpy.random.RandomState(utt.fetch_seed())
v_u = asarrayX(rng.uniform(size=(5,)))
outs = f(v_u, [0, 0, 0], 0)
assert numpy.allclose(outs[0], v_u + 1)
assert numpy.allclose(outs[1], v_u + 2)
assert numpy.allclose(outs[2], v_u + 3)
assert numpy.allclose(sh.get_value(), v_u[-1] + 4)
def test_eliminate_nonseqs(self):
W = tensor.scalar('W')
sh = theano.shared(asarrayX(2.))
x1 = tensor.vector('x1')
x2 = tensor.scalar('x2')
def rec_fn(*args):
w = args[-1]
return [(w + 1., # mitsot
w + 2., # sitsot
w + 3.), # nitsot
{sh: w + 4.}] # shared
[X1, X2, X3], updates = theano.scan(
rec_fn,
[],
[dict(initial=x1, taps=[-1, -3]), x2, None],
W,
n_steps=5,
truncate_gradient=-1,
go_backwards=False)
f = theano.function([W, x1, x2], [X1, X2, X3],
updates=updates,
mode=theano.Mode(linker='py'),
allow_input_downcast=True)
rng = numpy.random.RandomState(utt.fetch_seed())
v_w = asarrayX(rng.uniform())
outs = f(v_w, [0, 0, 0], 0)
assert numpy.allclose(outs[0], v_w + 1)
assert numpy.allclose(outs[1], v_w + 2)
assert numpy.allclose(outs[2], v_w + 3)
assert numpy.allclose(sh.get_value(), v_w + 4)
def test_speed(): def test_speed():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论