提交 a7b9a7e8 authored 作者: abergeron's avatar abergeron

Merge pull request #1970 from nouiz/opt_disabled_fix

[MRG, BUG] Scan Opt fixes, enhancements
......@@ -380,23 +380,29 @@ class _metadict:
self.l.append((item, value))
def __delitem__(self, item):
    """Delete `item` from the dict-like store.

    Hashable keys live in the dict ``self.d``; unhashable keys are kept
    as ``(key, value)`` pairs in the list ``self.l``.  Raises ``KeyError``
    if the item is in neither container.
    """
    try:
        if item in self.d:
            del self.d[item]
            return
    except TypeError as e:
        # `item in self.d` raises TypeError for unhashable keys; anything
        # else would be a real bug, so only swallow the unhashable case.
        assert "unhashable type" in str(e)
    # Fall back to a linear scan over the unhashable-key list.
    for i, (key, _val) in enumerate(self.l):
        if key == item:
            del self.l[i]
            return
    raise KeyError(item)
def discard(self, item):
    """Remove `item` if present; unlike ``__delitem__``, a missing item
    is silently ignored (mirrors ``set.discard``).

    Hashable keys live in the dict ``self.d``; unhashable keys are kept
    as ``(key, value)`` pairs in the list ``self.l``.
    """
    try:
        if item in self.d:
            del self.d[item]
            return
    except TypeError as e:
        # `item in self.d` raises TypeError for unhashable keys; anything
        # else would be a real bug, so only swallow the unhashable case.
        assert "unhashable type" in str(e)
    # Fall back to a linear scan over the unhashable-key list.
    for i, (key, _val) in enumerate(self.l):
        if key == item:
            del self.l[i]
            return
def get(self, item, default):
try:
......@@ -736,9 +742,14 @@ def pre_constant_merge(vars):
seen_var.add(var)
if isinstance(var, graph.Constant):
sig = var.signature()
if sig in const_sig_inv:
return const_sig_inv[sig]
const_sig_inv[sig] = var
try:
if sig in const_sig_inv:
return const_sig_inv[sig]
const_sig_inv[sig] = var
except TypeError: # unhashable type
# Some python object like slice aren't hashable. So
# don't merge them here.
pass
return var
if var.owner:
for idx, inp in enumerate(var.owner.inputs):
......
......@@ -82,22 +82,26 @@ def debugprint(obj, depth=-1, print_type=False,
done = dict()
results_to_print = []
order = []
if isinstance(obj, gof.Variable):
results_to_print.append(obj)
elif isinstance(obj, gof.Apply):
results_to_print.extend(obj.outputs)
elif isinstance(obj, Function):
results_to_print.extend(obj.maker.fgraph.outputs)
order = obj.maker.fgraph.toposort()
elif isinstance(obj, (list, tuple)):
results_to_print.extend(obj)
elif isinstance(obj, gof.FunctionGraph):
results_to_print.extend(obj.outputs)
order = obj.toposort()
elif isinstance(obj, (int, long, float, numpy.ndarray)):
print obj
if isinstance(obj, (list, tuple)):
lobj = obj
else:
raise TypeError("debugprint cannot print an object of this type", obj)
lobj = [obj]
for obj in lobj:
if isinstance(obj, gof.Variable):
results_to_print.append(obj)
elif isinstance(obj, gof.Apply):
results_to_print.extend(obj.outputs)
elif isinstance(obj, Function):
results_to_print.extend(obj.maker.fgraph.outputs)
order = obj.maker.fgraph.toposort()
elif isinstance(obj, gof.FunctionGraph):
results_to_print.extend(obj.outputs)
order = obj.toposort()
elif isinstance(obj, (int, long, float, numpy.ndarray)):
print obj
else:
raise TypeError("debugprint cannot print an object of this type",
obj)
for r in results_to_print:
debugmode.debugprint(r, depth=depth, done=done, print_type=print_type,
file=_file, order=order, ids=ids,
......
import theano
import numpy
import theano.tensor
class ScalarSoftsign(theano.scalar.UnaryScalarOp):
@staticmethod
def static_impl(x):
return x / (1.0 + abs(x))
def impl(self, x):
return ScalarSoftsign.static_impl(x)
def grad(self, inp, grads):
x, = inp
gz, = grads
......@@ -17,11 +19,15 @@ class ScalarSoftsign(theano.scalar.UnaryScalarOp):
return [gz / (d * d)]
else:
return NotImplemented
def c_code(self, node, name, inp, out, sub):
x, = inp
z, = out
if node.inputs[0].type in [theano.scalar.float32, theano.scalar.float64]:
if node.inputs[0].type in [theano.scalar.float32,
theano.scalar.float64]:
return "%(z)s = %(x)s / (1.0+fabs(%(x)s));" % locals()
raise NotImplementedError('only floating point x is implemented')
# Scalar softsign op instance; upgrade_to_float promotes integer inputs
# to float so the division is well-defined.
# NOTE: the diff scrape kept both the pre-wrap and post-wrap versions of
# this assignment, constructing the op twice — a single construction is kept.
scalar_softsign = ScalarSoftsign(theano.scalar.upgrade_to_float,
                                 name='scalar_softsign')
# Elemwise (tensor) version of the scalar softsign op.
softsign = theano.tensor.Elemwise(scalar_softsign, name='softsign')
......@@ -835,15 +835,14 @@ class Scan(PureOp):
n_steps = args[0]
seqs = []
if n_steps < 0:
n_steps = abs(n_steps)
for idx, seq in enumerate(args[1:self.seqs_arg_offset]):
if seq.shape[0] < n_steps:
raise ValueError(('Sequence is shorter then the required '
'number of steps : (n_steps, seq, '
'seq.shape):'), n_steps,
node.inputs[1 + idx],
seq.shape)
seqs.append(seq[::-1])
# History, in the past, this was used for backward
# scan. Now we reverse the inputs outside of scan.
raise IndexError(
"Scan was asked to run for negative number of step %d" %
n_steps)
elif n_steps == 0:
raise NotImplementedError(
"We didn't implemented yet the case where scan do 0 iteration")
else:
for idx, seq in enumerate(args[1:self.seqs_arg_offset]):
if seq.shape[0] < n_steps:
......@@ -1285,6 +1284,11 @@ class Scan(PureOp):
return ipos + opos
def connection_pattern(self, node):
# We cache this, as grad call connection_pattern, and it call
# grad in its turn. I was a case where theano.grad() took 4h
# that had many scan one inside each others.
if hasattr(node.tag, 'connection_pattern'):
return node.tag.connection_pattern
# The gradient wrt to n_steps is disconnected
connection_pattern = [[False for output in node.outputs]]
connection_pattern += [[False for output in node.outputs]
......@@ -1391,6 +1395,8 @@ class Scan(PureOp):
for k in xrange(len(connection_pattern)):
if connection_pattern[k][jidx]:
connection_pattern[k][iidx] = True
node.tag.connection_pattern = connection_pattern
return connection_pattern
### GRAD FUNCTION
......
"""
This module provides optimizations for scan
The Optimization provided in this file:
local opt: remove_constants_and_unused_inputs_scan,
constant_folding_for_scan2,
scan_merge_inouts
They are wrapped in in2out to create global opt.
global opt: ScanInplaceOptimizer,
PushOutNonSeqScan,
PushOutSeqScan,
PushOutDot1,
ScanMerge,
ScanSaveMem
How they are registered:
optdb: scan_eqopt1 (.1), scan_eqopt2(1.6), scan_inplace(75)
scan_eqopt1 -> scan_seqopt1
scan_seqopt1 -> in2out(remove_constants_and_unused_inputs_scan)(1),
PushOutNonSeqScan(2),
PushOutSeqScan(3), PushOutDot1(4)
scan_eqopt2 -> They are all global optimizers. (in2out converts local to global).
This is important, as the order matters and all global
optimizers run before local optimizers, in the order they were
registered. (So don't change the order we register them!)
If we convert to local optimizer, we must convert all of them
to local optimizer. But:
1) can ScanMerge be made local? Can we keep only this one global?
2) ScanSaveMem assert that we remove all nodes outputs,
we need to keep this.
3) ScanSaveMem supposes that the others ran before it.
I added an assert in one place, but didn't look for other places.
4) Moving this to a local opt could significantly speed up this opt,
as we currently pass over all nodes in the graph frequently for no good reason.
5) We register remove_constant_* many places, as some
opt create them and let this one clean up the mess.
Doing it that way, make things simpler for those already
complex opt.
in2out(constant_folding),
in2out(remove_constants_and_unused_inputs_scan1),
ScanMerge,
in2out(remove_constants_and_unused_inputs_scan2),
in2out(scan_merge_inouts),
ScanSaveMem,
in2out(remove_constants_and_unused_inputs_scan3)
"""
......@@ -858,6 +903,20 @@ class ScanSaveMem(gof.Optimizer):
if store_steps[i] != -1:
pval = select_max(pval, store_steps[i])
# TODO: Simplify the number of steps needed.
# FB: This need good testing, left to later.
# call get_scalar_constant_value()? it can
# return python/numpy scalar or numpy.ndarray currently.
#pval = pre_greedy_local_optimizer(list_opt_slice,
# pval)
#pval = pre_constant_merge([pval])[0]
#if (isinstance(pval, theano.tensor.TensorConstant) and
# pval.dtype.startswith('int')):
# try:
# pval = int(pval.data)
# except Exception:
# pass
store_steps[i] = pval
flag_store = True
......@@ -904,6 +963,8 @@ class ScanSaveMem(gof.Optimizer):
nw_inputs[offset + idx].owner.op.idx_list[0],
slice)):
assert isinstance(nw_inputs[offset + idx].owner.op,
tensor.IncSubtensor)
_nw_input = nw_inputs[offset + idx].owner.inputs[1]
cval = tensor.as_tensor_variable(val)
initl = tensor.as_tensor_variable(init_l[i])
......@@ -947,7 +1008,6 @@ class ScanSaveMem(gof.Optimizer):
if val == 0:
if idx < op.n_mit_sot + op.n_sit_sot:
_nw_input = nw_inputs[offset + idx].owner.inputs[1]
odx = op.n_mit_mot + idx
nw_input = scan_utils.expand(_nw_input, nw_steps)
nw_inputs[offset + idx] = nw_input
elif idx < (op.n_mit_sot + op.n_sit_sot +
......@@ -955,7 +1015,6 @@ class ScanSaveMem(gof.Optimizer):
in_idx = offset + idx + op.n_shared_outs
if nw_inputs[in_idx] == node.inputs[0]:
nw_inputs[in_idx] = nw_steps
odx = op.n_mit_mot + idx
# 3.5 Remove unwanted orphane outputs
(inps, outs, info, node_ins, compress_map) = \
......@@ -970,8 +1029,16 @@ class ScanSaveMem(gof.Optimizer):
# 3.6 Compose the new scan
# I need to make sure I'm not reapplying the same optimization
# twice since bad things usually happen if I do that
# TODO: why not check if save mem was done on any of merged nodes?
# That way, if none of them had save mem applied, it would
# be applied later.
info['_scan_savemem_visited'] = True
# TODO: currently we don't support scan with 0 step. So
# don't create one.
if theano.tensor.extract_constant(node_ins[0]) == 0:
return
# Do not call make_node for test_value
new_outs = scan_op.Scan(inps, outs, info)(*node_ins,
**dict(return_list=True))
......
This source diff could not be displayed because it is too large. You can view the blob instead.
差异被折叠。
差异被折叠。
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论