提交 a7b9a7e8 authored 作者: abergeron's avatar abergeron

Merge pull request #1970 from nouiz/opt_disabled_fix

[MRG, BUG] Scan Opt fixes, enhencements
...@@ -380,9 +380,12 @@ class _metadict: ...@@ -380,9 +380,12 @@ class _metadict:
self.l.append((item, value)) self.l.append((item, value))
def __delitem__(self, item): def __delitem__(self, item):
try:
if item in self.d: if item in self.d:
del self.d[item] del self.d[item]
else: return
except TypeError, e:
assert "unhashable type" in str(e)
for i, (key, val) in enumerate(self.l): for i, (key, val) in enumerate(self.l):
if key == item: if key == item:
del self.l[i] del self.l[i]
...@@ -390,9 +393,12 @@ class _metadict: ...@@ -390,9 +393,12 @@ class _metadict:
raise KeyError(item) raise KeyError(item)
def discard(self, item): def discard(self, item):
try:
if item in self.d: if item in self.d:
del self.d[item] del self.d[item]
else: return
except TypeError, e:
assert "unhashable type" in str(e)
for i, (key, val) in enumerate(self.l): for i, (key, val) in enumerate(self.l):
if key == item: if key == item:
del self.l[i] del self.l[i]
...@@ -736,9 +742,14 @@ def pre_constant_merge(vars): ...@@ -736,9 +742,14 @@ def pre_constant_merge(vars):
seen_var.add(var) seen_var.add(var)
if isinstance(var, graph.Constant): if isinstance(var, graph.Constant):
sig = var.signature() sig = var.signature()
try:
if sig in const_sig_inv: if sig in const_sig_inv:
return const_sig_inv[sig] return const_sig_inv[sig]
const_sig_inv[sig] = var const_sig_inv[sig] = var
except TypeError: # unhashable type
# Some python object like slice aren't hashable. So
# don't merge them here.
pass
return var return var
if var.owner: if var.owner:
for idx, inp in enumerate(var.owner.inputs): for idx, inp in enumerate(var.owner.inputs):
......
...@@ -82,6 +82,11 @@ def debugprint(obj, depth=-1, print_type=False, ...@@ -82,6 +82,11 @@ def debugprint(obj, depth=-1, print_type=False,
done = dict() done = dict()
results_to_print = [] results_to_print = []
order = [] order = []
if isinstance(obj, (list, tuple)):
lobj = obj
else:
lobj = [obj]
for obj in lobj:
if isinstance(obj, gof.Variable): if isinstance(obj, gof.Variable):
results_to_print.append(obj) results_to_print.append(obj)
elif isinstance(obj, gof.Apply): elif isinstance(obj, gof.Apply):
...@@ -89,15 +94,14 @@ def debugprint(obj, depth=-1, print_type=False, ...@@ -89,15 +94,14 @@ def debugprint(obj, depth=-1, print_type=False,
elif isinstance(obj, Function): elif isinstance(obj, Function):
results_to_print.extend(obj.maker.fgraph.outputs) results_to_print.extend(obj.maker.fgraph.outputs)
order = obj.maker.fgraph.toposort() order = obj.maker.fgraph.toposort()
elif isinstance(obj, (list, tuple)):
results_to_print.extend(obj)
elif isinstance(obj, gof.FunctionGraph): elif isinstance(obj, gof.FunctionGraph):
results_to_print.extend(obj.outputs) results_to_print.extend(obj.outputs)
order = obj.toposort() order = obj.toposort()
elif isinstance(obj, (int, long, float, numpy.ndarray)): elif isinstance(obj, (int, long, float, numpy.ndarray)):
print obj print obj
else: else:
raise TypeError("debugprint cannot print an object of this type", obj) raise TypeError("debugprint cannot print an object of this type",
obj)
for r in results_to_print: for r in results_to_print:
debugmode.debugprint(r, depth=depth, done=done, print_type=print_type, debugmode.debugprint(r, depth=depth, done=done, print_type=print_type,
file=_file, order=order, ids=ids, file=_file, order=order, ids=ids,
......
import theano import theano
import numpy
import theano.tensor import theano.tensor
class ScalarSoftsign(theano.scalar.UnaryScalarOp): class ScalarSoftsign(theano.scalar.UnaryScalarOp):
@staticmethod @staticmethod
def static_impl(x): def static_impl(x):
return x / (1.0 + abs(x)) return x / (1.0 + abs(x))
def impl(self, x): def impl(self, x):
return ScalarSoftsign.static_impl(x) return ScalarSoftsign.static_impl(x)
def grad(self, inp, grads): def grad(self, inp, grads):
x, = inp x, = inp
gz, = grads gz, = grads
...@@ -17,11 +19,15 @@ class ScalarSoftsign(theano.scalar.UnaryScalarOp): ...@@ -17,11 +19,15 @@ class ScalarSoftsign(theano.scalar.UnaryScalarOp):
return [gz / (d * d)] return [gz / (d * d)]
else: else:
return NotImplemented return NotImplemented
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
x, = inp x, = inp
z, = out z, = out
if node.inputs[0].type in [theano.scalar.float32, theano.scalar.float64]: if node.inputs[0].type in [theano.scalar.float32,
theano.scalar.float64]:
return "%(z)s = %(x)s / (1.0+fabs(%(x)s));" % locals() return "%(z)s = %(x)s / (1.0+fabs(%(x)s));" % locals()
raise NotImplementedError('only floating point x is implemented') raise NotImplementedError('only floating point x is implemented')
scalar_softsign = ScalarSoftsign(theano.scalar.upgrade_to_float, name='scalar_softsign')
scalar_softsign = ScalarSoftsign(theano.scalar.upgrade_to_float,
name='scalar_softsign')
softsign = theano.tensor.Elemwise(scalar_softsign, name='softsign') softsign = theano.tensor.Elemwise(scalar_softsign, name='softsign')
...@@ -835,15 +835,14 @@ class Scan(PureOp): ...@@ -835,15 +835,14 @@ class Scan(PureOp):
n_steps = args[0] n_steps = args[0]
seqs = [] seqs = []
if n_steps < 0: if n_steps < 0:
n_steps = abs(n_steps) # History, in the past, this was used for backward
for idx, seq in enumerate(args[1:self.seqs_arg_offset]): # scan. Now we reverse the inputs outside of scan.
if seq.shape[0] < n_steps: raise IndexError(
raise ValueError(('Sequence is shorter then the required ' "Scan was asked to run for negative number of step %d" %
'number of steps : (n_steps, seq, ' n_steps)
'seq.shape):'), n_steps, elif n_steps == 0:
node.inputs[1 + idx], raise NotImplementedError(
seq.shape) "We didn't implemented yet the case where scan do 0 iteration")
seqs.append(seq[::-1])
else: else:
for idx, seq in enumerate(args[1:self.seqs_arg_offset]): for idx, seq in enumerate(args[1:self.seqs_arg_offset]):
if seq.shape[0] < n_steps: if seq.shape[0] < n_steps:
...@@ -1285,6 +1284,11 @@ class Scan(PureOp): ...@@ -1285,6 +1284,11 @@ class Scan(PureOp):
return ipos + opos return ipos + opos
def connection_pattern(self, node): def connection_pattern(self, node):
# We cache this, as grad call connection_pattern, and it call
# grad in its turn. I was a case where theano.grad() took 4h
# that had many scan one inside each others.
if hasattr(node.tag, 'connection_pattern'):
return node.tag.connection_pattern
# The gradient wrt to n_steps is disconnected # The gradient wrt to n_steps is disconnected
connection_pattern = [[False for output in node.outputs]] connection_pattern = [[False for output in node.outputs]]
connection_pattern += [[False for output in node.outputs] connection_pattern += [[False for output in node.outputs]
...@@ -1391,6 +1395,8 @@ class Scan(PureOp): ...@@ -1391,6 +1395,8 @@ class Scan(PureOp):
for k in xrange(len(connection_pattern)): for k in xrange(len(connection_pattern)):
if connection_pattern[k][jidx]: if connection_pattern[k][jidx]:
connection_pattern[k][iidx] = True connection_pattern[k][iidx] = True
node.tag.connection_pattern = connection_pattern
return connection_pattern return connection_pattern
### GRAD FUNCTION ### GRAD FUNCTION
......
""" """
This module provides optimizations for scan This module provides optimizations for scan
The Optimization provided in this file:
local opt: remove_constants_and_unused_inputs_scan,
constant_folding_for_scan2,
scan_merge_inouts
They are wrapped in in2out to create global opt.
global opt: ScanInplaceOptimizer,
PushOutNonSeqScan,
PushOutSeqScan,
PushOutDot1,
ScanMerge,
ScanSaveMem
How the are registered:
optdb: scan_eqopt1 (.1), scan_eqopt2(1.6), scan_inplace(75)
scan_eqopt1 -> scan_seqopt1
scan_seqopt1 -> in2out(remove_constants_and_unused_inputs_scan)(1),
PushOutNonSeqScan(2),
PushOutSeqScan(3), PushOutDot1(4)
scan_eqopt2 -> They are all global optimizer. (in2out convert local to global).
This is important, as the order is important and all global
optimizer run before local optimizer in the order they where
registered. (So don't change the order we register them!)
If we convert to local optimizer, we must convert all of them
to local optimizer. But:
1) can ScanMerge be made local? Can we keep only this one global?
2) ScanSaveMem assert that we remove all nodes outputs,
we need to keep this.
3) It is ScanSaveMem suppose the the others ran before.
I added an assert at one place, but didn't looked for other place.
4) Moving this to local opt could speed up significant this opt,
as we pass frequently on all nodes in the graph for no good reason.
5) We register remove_constant_* many places, as some
opt create them and let this one clean up the mess.
Doing it that way, make things simpler for those already
complex opt.
in2out(constant_folding),
in2out(remove_constants_and_unused_inputs_scan1),
ScanMerge,
in2out(remove_constants_and_unused_inputs_scan2),
in2out(scan_merge_inouts),
ScanSaveMem,
in2out(remove_constants_and_unused_inputs_scan3)
""" """
...@@ -858,6 +903,20 @@ class ScanSaveMem(gof.Optimizer): ...@@ -858,6 +903,20 @@ class ScanSaveMem(gof.Optimizer):
if store_steps[i] != -1: if store_steps[i] != -1:
pval = select_max(pval, store_steps[i]) pval = select_max(pval, store_steps[i])
# TODO: Simplify the number of steps needed.
# FB: This need good testing, left to later.
# call get_scalar_constant_value()? it can
# return python/numpy scalar or numpy.ndarray currently.
#pval = pre_greedy_local_optimizer(list_opt_slice,
# pval)
#pval = pre_constant_merge([pval])[0]
#if (isinstance(pval, theano.tensor.TensorConstant) and
# pval.dtype.startswith('int')):
# try:
# pval = int(pval.data)
# except Exception:
# pass
store_steps[i] = pval store_steps[i] = pval
flag_store = True flag_store = True
...@@ -904,6 +963,8 @@ class ScanSaveMem(gof.Optimizer): ...@@ -904,6 +963,8 @@ class ScanSaveMem(gof.Optimizer):
nw_inputs[offset + idx].owner.op.idx_list[0], nw_inputs[offset + idx].owner.op.idx_list[0],
slice)): slice)):
assert isinstance(nw_inputs[offset + idx].owner.op,
tensor.IncSubtensor)
_nw_input = nw_inputs[offset + idx].owner.inputs[1] _nw_input = nw_inputs[offset + idx].owner.inputs[1]
cval = tensor.as_tensor_variable(val) cval = tensor.as_tensor_variable(val)
initl = tensor.as_tensor_variable(init_l[i]) initl = tensor.as_tensor_variable(init_l[i])
...@@ -947,7 +1008,6 @@ class ScanSaveMem(gof.Optimizer): ...@@ -947,7 +1008,6 @@ class ScanSaveMem(gof.Optimizer):
if val == 0: if val == 0:
if idx < op.n_mit_sot + op.n_sit_sot: if idx < op.n_mit_sot + op.n_sit_sot:
_nw_input = nw_inputs[offset + idx].owner.inputs[1] _nw_input = nw_inputs[offset + idx].owner.inputs[1]
odx = op.n_mit_mot + idx
nw_input = scan_utils.expand(_nw_input, nw_steps) nw_input = scan_utils.expand(_nw_input, nw_steps)
nw_inputs[offset + idx] = nw_input nw_inputs[offset + idx] = nw_input
elif idx < (op.n_mit_sot + op.n_sit_sot + elif idx < (op.n_mit_sot + op.n_sit_sot +
...@@ -955,7 +1015,6 @@ class ScanSaveMem(gof.Optimizer): ...@@ -955,7 +1015,6 @@ class ScanSaveMem(gof.Optimizer):
in_idx = offset + idx + op.n_shared_outs in_idx = offset + idx + op.n_shared_outs
if nw_inputs[in_idx] == node.inputs[0]: if nw_inputs[in_idx] == node.inputs[0]:
nw_inputs[in_idx] = nw_steps nw_inputs[in_idx] = nw_steps
odx = op.n_mit_mot + idx
# 3.5 Remove unwanted orphane outputs # 3.5 Remove unwanted orphane outputs
(inps, outs, info, node_ins, compress_map) = \ (inps, outs, info, node_ins, compress_map) = \
...@@ -970,8 +1029,16 @@ class ScanSaveMem(gof.Optimizer): ...@@ -970,8 +1029,16 @@ class ScanSaveMem(gof.Optimizer):
# 3.6 Compose the new scan # 3.6 Compose the new scan
# I need to make sure I'm not reapplying the same optimization # I need to make sure I'm not reapplying the same optimization
# twice since bad things usually happen if I do that # twice since bad things usually happen if I do that
# TODO: why not check if save mem was done on any of merged nodes?
# That way, if none of them had save mem applied, it would
# be applied later.
info['_scan_savemem_visited'] = True info['_scan_savemem_visited'] = True
# TODO: currently we don't support scan with 0 step. So
# don't create one.
if theano.tensor.extract_constant(node_ins[0]) == 0:
return
# Do not call make_node for test_value # Do not call make_node for test_value
new_outs = scan_op.Scan(inps, outs, info)(*node_ins, new_outs = scan_op.Scan(inps, outs, info)(*node_ins,
**dict(return_list=True)) **dict(return_list=True))
......
This source diff could not be displayed because it is too large. You can view the blob instead.
差异被折叠。
差异被折叠。
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论