Commit e6633d20 authored by Razvan Pascanu

Merge pull request #1578 from lamblin/fix_scan_merge_inouts

Fix scan merge inouts
......@@ -6,6 +6,7 @@ types that it can raise
"""
import sys
import theano
from theano.gof import graph
from theano.gof import utils
from theano.gof import toolbox
......@@ -431,6 +432,23 @@ class FunctionGraph(utils.object2):
# because it makes it easier to implement some optimizations for multiple-output ops
return
if theano.config.compute_test_value != 'off':
try:
tval = theano.gof.op.get_test_value(r)
new_tval = theano.gof.op.get_test_value(new_r)
except AttributeError:
pass
else:
tval_shape = getattr(tval, 'shape', None)
new_tval_shape = getattr(new_tval, 'shape', None)
if tval_shape != new_tval_shape:
raise AssertionError(
"The replacement variable has a test value with "
"a shape different from the original variable's "
"test value. Original: %s, new: %s"
% (tval_shape, new_tval_shape),
r, new_r, str(reason))
for node, i in list(r.clients): # copy the client list for iteration
assert (node == 'output' and self.outputs[i] is r) or (node.inputs[i] is r)
self.change_input(node, i, new_r, reason=reason)
......
......@@ -1280,6 +1280,9 @@ def scan_merge_inouts(node):
if not isinstance(node.op, scan_op.Scan):
return False
# Do a first pass to merge identical external inputs.
# Equivalent inputs will be stored in inp_equiv, then a new
# scan node created without duplicates.
a = scan_args(node.inputs, node.outputs,
node.op.inputs, node.op.outputs, node.op.info)
......@@ -1332,7 +1335,9 @@ def scan_merge_inouts(node):
else:
na = a
# start again
# Now that the identical external inputs have been merged, we do a new
# loop in order to merge external outputs that compute the same things
# from the same inputs.
left = []
right = []
......@@ -1369,32 +1374,42 @@ def scan_merge_inouts(node):
else:
seen[(oms, sl)] = ims
def map_out(i, o, seen):
for si, so in seen:
if equal_computations([i], [si], left, right):
return so
seen.append((i, o))
return o
def map_nitsot_out(i, o, sh, seen):
for p, (si, so, ssh) in enumerate(seen):
if equal_computations([i], [si], left, right):
def map_out(outer_i, inner_o, outer_o, seen):
    """Return the merged outer output for an (outer input, inner output) pair.

    The first time a pair is seen it is recorded in `seen` and the given
    `outer_o` is kept. If an equivalent pair was already recorded, its
    outer output is reused instead, so duplicate scan outputs collapse to
    a single one.

    The outer inputs have to match as well: two otherwise-identical inner
    outputs fed by different outer inputs may have different sizes, and
    their outer outputs must then stay separate.
    """
    for prev_outer_i, prev_inner_o, prev_outer_o in seen:
        same_inner = equal_computations([inner_o], [prev_inner_o],
                                        left, right)
        if same_inner and outer_i == prev_outer_i:
            return prev_outer_o
    seen.append((outer_i, inner_o, outer_o))
    return outer_o
def map_nitsot_out(outer_i, inner_o, outer_o, sh, seen):
    # Like map_out, but also checks the needed shape.
    # `sh` is the requested length of this nit_sot output (number of
    # steps scan must store). Two nit_sot outputs can only be merged if
    # they compute the same thing from the same outer input AND their
    # requested lengths are compatible.
    for p, (s_outer_i, s_inner_o, s_outer_o, ssh) in enumerate(seen):
        if (equal_computations([inner_o], [s_inner_o], left, right)
                and outer_i == s_outer_i):
            if equal_computations([sh], [ssh]):
                # Same symbolic shape: reuse the previously-seen output.
                return s_outer_o
            try:
                vsh = int(opt.get_scalar_constant_value(sh))
                vssh = int(opt.get_scalar_constant_value(ssh))
            except tensor.NotScalarConstantError:
                # Shapes differ symbolically and at least one is not a
                # known constant: cannot merge, keep this output as-is.
                return outer_o
            if vsh == vssh:
                return s_outer_o
            elif vsh > vssh:
                # The new output needs more steps than the recorded one:
                # replace the entry so later candidates compare against
                # the longer output.
                seen[p] = (outer_i, inner_o, outer_o, sh)
                return outer_o
            else:
                # The recorded output is longer: a prefix of it provides
                # this shorter output.
                return s_outer_o[:vsh]
    # First time this (outer input, inner output, shape) combination is
    # seen: record it and keep the provided outer output.
    seen.append((outer_i, inner_o, outer_o, sh))
    return outer_o
seen = []
......@@ -1410,36 +1425,52 @@ def scan_merge_inouts(node):
# If x is a scalar, then it means its value is the number of
# items scan is supposed to store for this nit_sot sequence
shapes.append(x)
tmp = [map_nitsot_out(i, o, sh, seen)
for i, o, sh in zip(na.inner_out_nit_sot,
na.outer_out_nit_sot,
shapes)]
na.outer_out_nit_sot = [map_nitsot_out(i, o, sh, seen)
for i, o, sh in zip(na.inner_out_nit_sot,
na.outer_out_nit_sot,
shapes)]
assert len(na.outer_in_nit_sot) == len(na.inner_out_nit_sot)
assert len(na.inner_out_nit_sot) == len(na.outer_out_nit_sot)
assert len(na.outer_out_nit_sot) == len(shapes)
na.outer_out_nit_sot = [
map_nitsot_out(outer_i, inner_o, outer_o, sh, seen)
for outer_i, inner_o, outer_o, sh in zip(na.outer_in_nit_sot,
na.inner_out_nit_sot,
na.outer_out_nit_sot,
shapes)]
seen = []
na.outer_out_sit_sot = [map_out(i, o, seen)
for i, o in zip(na.inner_out_sit_sot,
na.outer_out_sit_sot)]
assert len(na.outer_in_sit_sot) == len(na.inner_out_sit_sot)
assert len(na.inner_out_sit_sot) == len(na.outer_out_sit_sot)
na.outer_out_sit_sot = [
map_out(outer_i, inner_o, outer_o, seen)
for outer_i, inner_o, outer_o in zip(na.outer_in_sit_sot,
na.inner_out_sit_sot,
na.outer_out_sit_sot)]
seen = []
na.outer_out_mit_sot = [map_out(i, o, seen)
for i, o in zip(na.inner_out_mit_sot,
na.outer_out_mit_sot)]
assert len(na.outer_in_mit_sot) == len(na.inner_out_mit_sot)
assert len(na.inner_out_mit_sot) == len(na.outer_out_mit_sot)
na.outer_out_mit_sot = [
map_out(outer_i, inner_o, outer_o, seen)
for outer_i, inner_o, outer_o in zip(na.outer_in_mit_sot,
na.inner_out_mit_sot,
na.outer_out_mit_sot)]
seen = []
new_outer_out_mit_mot = []
for imm, omm, osl in zip(na.inner_out_mit_mot,
na.outer_out_mit_mot, na.mit_mot_out_slices):
for simm, somm, sosl in seen:
if osl == sosl and equal_computations(imm, simm, left, right):
new_outer_out_mit_mot.append(somm)
assert len(na.outer_in_mit_mot) == len(na.inner_out_mit_mot)
assert len(na.inner_out_mit_mot) == len(na.outer_out_mit_mot)
assert len(na.outer_out_mit_mot) == len(na.mit_mot_out_slices)
for outer_imm, inner_omm, outer_omm, osl in zip(na.outer_in_mit_mot,
na.inner_out_mit_mot,
na.outer_out_mit_mot,
na.mit_mot_out_slices):
for s_outer_imm, s_inner_omm, s_outer_omm, sosl in seen:
if (osl == sosl
and equal_computations(inner_omm, s_inner_omm, left, right)
and outer_imm == s_outer_imm):
new_outer_out_mit_mot.append(s_outer_omm)
break
else:
seen.append((imm, omm, osl))
new_outer_out_mit_mot.append(omm)
seen.append((outer_imm, inner_omm, outer_omm, osl))
new_outer_out_mit_mot.append(outer_omm)
na.outer_out_mit_mot = new_outer_out_mit_mot
return na.outer_outputs
......
import numpy
import unittest
import theano
from theano import config
from theano import tensor as T
from theano.tests import unittest_tools as utt
mode = theano.compile.mode.get_mode(config.mode)
class TestGaussNewton(unittest.TestCase):
    """
    Regression test for code exhibiting various optimization errors.

    This test case is based on code by Sigurd Spieckermann. It builds a
    simple RNN with scan, a Gauss-Newton matrix-vector product via R-op,
    and compiles the whole graph, which used to trigger errors in the
    "scan_merge_inouts" optimization.
    """

    def setUp(self):
        self.rng = numpy.random.RandomState(utt.fetch_seed())

    def _run(self, num_features, num_timesteps, batch_size, mode):
        # determine shapes of inputs and targets depending on the batch size
        if batch_size == 1:
            inputs_size = (num_timesteps, num_features)
            targets_size = (num_timesteps, 1)
        else:
            inputs_size = (num_timesteps, batch_size, num_features)
            targets_size = (num_timesteps, batch_size, 1)

        # make inputs and targets shared variables
        inputs = theano.shared(
            self.rng.uniform(size=inputs_size).astype(config.floatX),
            borrow=True)
        targets = theano.shared(
            self.rng.uniform(size=targets_size).astype(config.floatX),
            borrow=True)

        # create symbolic inputs and targets variables
        if batch_size == 1:
            x = T.matrix('inputs')
            t = T.matrix('targets')
        else:
            x = T.tensor3('inputs')
            # Fixed: the targets variable was mislabelled 'inputs'.
            t = T.tensor3('targets')
        x.tag.test_value = inputs.get_value(borrow=True)
        t.tag.test_value = targets.get_value(borrow=True)

        # create a set of parameters for a simple RNN
        W_xh = theano.shared(
            (0.01 * self.rng.uniform(
                size=(num_features, 10))).astype(config.floatX),
            borrow=True)
        W_hh = theano.shared(
            (0.01 * self.rng.uniform(size=(10, 10))).astype(config.floatX),
            borrow=True)
        W_hy = theano.shared(
            (0.01 * self.rng.uniform(size=(10, 1))).astype(config.floatX),
            borrow=True)
        b_h = theano.shared(numpy.zeros(10).astype(config.floatX), borrow=True)
        b_y = theano.shared(numpy.zeros(1).astype(config.floatX), borrow=True)
        params = [W_xh, W_hh, W_hy, b_h, b_y]

        # recurrent function
        def step(x_t, h_tm1):
            h = T.tanh(T.dot(h_tm1, W_hh) + T.dot(x_t, W_xh) + b_h)
            return h

        # build recurrent graph
        if batch_size == 1:
            h_0 = T.alloc(0.0, 10).astype(config.floatX)
        else:
            h_0 = T.alloc(0.0, batch_size, 10).astype(config.floatX)
        h, updates = theano.scan(step,
                                 sequences=[x],
                                 outputs_info=[h_0])
        # network output
        y = T.dot(h, W_hy) + b_y

        # Create Gauss-Newton-Matrix object. Not really of any use here, but I
        # need it for Hessian-Free optimization.
        gn = GaussNewtonMatrix(y)

        # compute MSE
        cost = ((t - y) ** 2).sum(axis=1).mean()

        # Compute the cost at some other point in the parameter
        # space. Not really of any use here, but this is how I do it
        # during certain iterations of CG in the HF algorithm. There,
        # it's in fact `pi + current update proposal`. For simplicity,
        # I just multiply by 2 here.
        cost_ = theano.clone(cost,
                             replace=dict([(pi, 2 * pi) for pi in params]))

        # Compute Gauss-Newton-Matrix times some vector `v` which is `p` in CG,
        # but for simplicity, I just take the parameters vector because it's
        # already there.
        Gv = gn(v=params, cost=cost, parameters=params, damp=T.constant(1.0))

        # compile Theano function
        f = theano.function([], [cost_] + Gv, givens={x: inputs, t: targets},
                            mode=mode)
        # execute
        f()

    def test_batch(self):
        # This runs fine. The batch size is set to something greater than 1,
        # i.e. the data is represented by a tensor3 object.
        self._run(100, 10, batch_size=5, mode=mode)

    def test_nobatch(self):
        # This used to give an error due to optimization "scan_merge_inouts".
        # The batch size is set to 1 and the data is represented by a matrix.
        # As of 2013-10-24, it still triggers an optimization error due to
        # "remove_constants_and_unused_inputs_scan".
        mode_exc = mode.excluding("remove_constants_and_unused_inputs_scan")
        self._run(100, 10, batch_size=1, mode=mode_exc)
class GaussNewtonMatrix(object):
    """Implicit Gauss-Newton matrix, evaluated via matrix-vector products."""

    def __init__(self, s):
        # `s` is the linear network outputs, i.e. the network output
        # without having applied the activation function
        self._s = s

    def __call__(self, v, cost, parameters, damp):
        """Compute the Gauss-Newton matrix right-multiplied by `v`.

        Uses the R-operator to get J*v, back-propagates through the cost
        to apply the Hessian of the cost w.r.t. the linear outputs, then
        back-propagates once more to apply J^T, and finally adds Tikhonov
        damping `damp * v`.
        """
        lin_out = self._s
        jv = T.Rop(lin_out, parameters, v)
        hjv = T.grad(T.sum(T.grad(cost, lin_out) * jv), lin_out,
                     consider_constant=[jv])
        jhjv = T.grad(T.sum(hjv * lin_out), parameters,
                      consider_constant=[hjv, jv])
        # apply Tikhonov damping
        return [gv_i + damp * v_i for gv_i, v_i in zip(jhjv, v)]
......@@ -2579,7 +2579,7 @@ class Alloc(gof.Op):
def R_op(self, inputs, eval_points):
if eval_points[0] is None:
return [None]
return self.make_node(eval_points[0], *inputs[1:]).outputs
return self(eval_points[0], *inputs[1:], **dict(return_list=True))
def do_constant_folding(self, node):
if not getattr(node.outputs[0], 'clients', []):
......@@ -3275,7 +3275,7 @@ class Rebroadcast(Op):
def R_op(self, inputs, eval_points):
if eval_points[0] is None:
return [None]
return self.make_node(*eval_points).outputs
return self(*eval_points, **dict(return_list=True))
def addbroadcast(x, *axes):
......@@ -3805,7 +3805,7 @@ class Reshape(Op):
def R_op(self, inputs, eval_points):
if eval_points[0] is None:
return [None]
return self.make_node(eval_points[0], *inputs[1:]).outputs
return self(eval_points[0], *inputs[1:], **dict(return_list=True))
def infer_shape(self, node, ishapes):
# inputs[1] can contain at most one value of '-1', meaning the actual
......@@ -4600,7 +4600,7 @@ class Dot(Op):
eval_point_values = [ev0, ev1]
for i in xrange(2):
if eval_point_values[i] and \
if eval_point_values[i] is not None and \
input_values[i].shape != eval_point_values[i].shape:
raise ValueError('input ' + str(i) + ' and eval_point ' +
str(i) + ' to Dot.R_op '
......
......@@ -273,7 +273,7 @@ class DimShuffle(Op):
def R_op(self, inputs, eval_points):
if None in eval_points:
return [None]
return self.make_node(*eval_points).outputs
return self(*eval_points, **dict(return_list=True))
def c_code(self, node, name, inp, out, sub):
input, = inp
......@@ -616,7 +616,7 @@ class Elemwise(Op):
return self.name
def R_op(self, inputs, eval_points):
outs = self.make_node(*inputs).outputs
outs = self(*inputs, **dict(return_list=True))
rval = [None for x in outs]
# For each output
for idx, out in enumerate(outs):
......@@ -1882,7 +1882,7 @@ class Sum(CAReduceDtype):
# part of self
if None in eval_points:
return [None]
return self.make_node(*eval_points).outputs
return self(*eval_points, **dict(return_list=True))
def __str__(self):
if self.axis is None:
......
......@@ -263,10 +263,11 @@ def inplace_elemwise_optimizer_op(OP):
scalar.transfer_type(
*[inplace_pattern.get(i, None) \
for i in xrange(len(node.outputs))]))
new = OP(new_scal, inplace_pattern).make_node(
*node.inputs)
new_outputs = OP(new_scal, inplace_pattern)(
*node.inputs, **dict(return_list=True))
new_node = new_outputs[0].owner
for r, new_r in zip(node.outputs, new.outputs):
for r, new_r in zip(node.outputs, new_outputs):
fgraph.replace(r, new_r,
reason="inplace_elemwise_optimizer")
nb_change_no_validate += 1
......@@ -284,7 +285,7 @@ def inplace_elemwise_optimizer_op(OP):
fgraph.revert(chk)
continue
candidate_inputs.remove(candidate_input)
node = new
node = new_node
baseline = inplace_pattern
break
......
......@@ -883,7 +883,7 @@ class Subtensor(Op):
# (they should be defaulted to zeros_like by the global R_op)
if eval_points[0] is None:
return [None]
return self.make_node(eval_points[0], *inputs[1:]).outputs
return self(eval_points[0], *inputs[1:], **dict(return_list=True))
class SubtensorPrinter:
......@@ -1414,8 +1414,8 @@ class IncSubtensor(Op):
return [None]
# Again we ignore eval points for indices because incsubtensor is
# not differentiable wrt to those
return self.make_node(eval_points[0], eval_points[1],
*inputs[2:]).outputs
return self(eval_points[0], eval_points[1], *inputs[2:],
**dict(return_list=True))
def connection_pattern(self, node):
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Please finish editing this comment first!
Register or sign in to comment