提交 e6633d20 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Merge pull request #1578 from lamblin/fix_scan_merge_inouts

Fix scan merge inouts
...@@ -6,6 +6,7 @@ types that it can raise ...@@ -6,6 +6,7 @@ types that it can raise
""" """
import sys import sys
import theano
from theano.gof import graph from theano.gof import graph
from theano.gof import utils from theano.gof import utils
from theano.gof import toolbox from theano.gof import toolbox
...@@ -431,6 +432,23 @@ class FunctionGraph(utils.object2): ...@@ -431,6 +432,23 @@ class FunctionGraph(utils.object2):
# because it makes it easier to implement some optimizations for multiple-output ops # because it makes it easier to implement some optimizations for multiple-output ops
return return
if theano.config.compute_test_value != 'off':
try:
tval = theano.gof.op.get_test_value(r)
new_tval = theano.gof.op.get_test_value(new_r)
except AttributeError:
pass
else:
tval_shape = getattr(tval, 'shape', None)
new_tval_shape = getattr(new_tval, 'shape', None)
if tval_shape != new_tval_shape:
raise AssertionError(
"The replacement variable has a test value with "
"a shape different from the original variable's "
"test value. Original: %s, new: %s"
% (tval_shape, new_tval_shape),
r, new_r, str(reason))
for node, i in list(r.clients): # copy the client list for iteration for node, i in list(r.clients): # copy the client list for iteration
assert (node == 'output' and self.outputs[i] is r) or (node.inputs[i] is r) assert (node == 'output' and self.outputs[i] is r) or (node.inputs[i] is r)
self.change_input(node, i, new_r, reason=reason) self.change_input(node, i, new_r, reason=reason)
......
...@@ -1280,6 +1280,9 @@ def scan_merge_inouts(node): ...@@ -1280,6 +1280,9 @@ def scan_merge_inouts(node):
if not isinstance(node.op, scan_op.Scan): if not isinstance(node.op, scan_op.Scan):
return False return False
# Do a first pass to merge identical external inputs.
# Equivalent inputs will be stored in inp_equiv, then a new
# scan node created without duplicates.
a = scan_args(node.inputs, node.outputs, a = scan_args(node.inputs, node.outputs,
node.op.inputs, node.op.outputs, node.op.info) node.op.inputs, node.op.outputs, node.op.info)
...@@ -1332,7 +1335,9 @@ def scan_merge_inouts(node): ...@@ -1332,7 +1335,9 @@ def scan_merge_inouts(node):
else: else:
na = a na = a
# start again # Now that the identical external inputs have been merged, we do a new
# loop in order to merge external outputs that compute the same things
# from the same inputs.
left = [] left = []
right = [] right = []
...@@ -1369,32 +1374,42 @@ def scan_merge_inouts(node): ...@@ -1369,32 +1374,42 @@ def scan_merge_inouts(node):
else: else:
seen[(oms, sl)] = ims seen[(oms, sl)] = ims
def map_out(i, o, seen): def map_out(outer_i, inner_o, outer_o, seen):
for si, so in seen: # Return the outer input corresponding to an
if equal_computations([i], [si], left, right): # (outer input, inner output) pair. If we see that pair for the first
return so # time, return the provided outer output. If an equivalent pair had
seen.append((i, o)) # already been seen, return that one instead.
return o # Note that we need to check that the outer input match as well,
# because they could have different sizes, and the corresponding
def map_nitsot_out(i, o, sh, seen): # outer outputs cannot be merged in that case.
for p, (si, so, ssh) in enumerate(seen): for s_outer_i, s_inner_o, s_outer_o in seen:
if equal_computations([i], [si], left, right): if (equal_computations([inner_o], [s_inner_o], left, right)
and outer_i == s_outer_i):
return s_outer_o
seen.append((outer_i, inner_o, outer_o))
return outer_o
def map_nitsot_out(outer_i, inner_o, outer_o, sh, seen):
# Like map_out, but also checks the needed shape.
for p, (s_outer_i, s_inner_o, s_outer_o, ssh) in enumerate(seen):
if (equal_computations([inner_o], [s_inner_o], left, right)
and outer_i == s_outer_i):
if equal_computations([sh], [ssh]): if equal_computations([sh], [ssh]):
return so return s_outer_o
try: try:
vsh = int(opt.get_scalar_constant_value(sh)) vsh = int(opt.get_scalar_constant_value(sh))
vssh = int(opt.get_scalar_constant_value(ssh)) vssh = int(opt.get_scalar_constant_value(ssh))
except tensor.NotScalarConstantError: except tensor.NotScalarConstantError:
return o return outer_o
if vsh == vssh: if vsh == vssh:
return so return s_outer_o
elif vsh > vssh: elif vsh > vssh:
seen[p] = (i, o, sh) seen[p] = (outer_i, inner_o, outer_o, sh)
return o return outer_o
else: else:
return so[:vsh] return s_outer_o[:vsh]
seen.append((i, o, sh)) seen.append((outer_i, inner_o, outer_o, sh))
return o return outer_o
seen = [] seen = []
...@@ -1410,36 +1425,52 @@ def scan_merge_inouts(node): ...@@ -1410,36 +1425,52 @@ def scan_merge_inouts(node):
# If x is a scalar, then it means its value is the number of # If x is a scalar, then it means its value is the number of
# items scan is supposed to store for this nit_sot sequence # items scan is supposed to store for this nit_sot sequence
shapes.append(x) shapes.append(x)
tmp = [map_nitsot_out(i, o, sh, seen) assert len(na.outer_in_nit_sot) == len(na.inner_out_nit_sot)
for i, o, sh in zip(na.inner_out_nit_sot, assert len(na.inner_out_nit_sot) == len(na.outer_out_nit_sot)
na.outer_out_nit_sot, assert len(na.outer_out_nit_sot) == len(shapes)
shapes)] na.outer_out_nit_sot = [
na.outer_out_nit_sot = [map_nitsot_out(i, o, sh, seen) map_nitsot_out(outer_i, inner_o, outer_o, sh, seen)
for i, o, sh in zip(na.inner_out_nit_sot, for outer_i, inner_o, outer_o, sh in zip(na.outer_in_nit_sot,
na.inner_out_nit_sot,
na.outer_out_nit_sot, na.outer_out_nit_sot,
shapes)] shapes)]
seen = [] seen = []
na.outer_out_sit_sot = [map_out(i, o, seen) assert len(na.outer_in_sit_sot) == len(na.inner_out_sit_sot)
for i, o in zip(na.inner_out_sit_sot, assert len(na.inner_out_sit_sot) == len(na.outer_out_sit_sot)
na.outer_out_sit_sot = [
map_out(outer_i, inner_o, outer_o, seen)
for outer_i, inner_o, outer_o in zip(na.outer_in_sit_sot,
na.inner_out_sit_sot,
na.outer_out_sit_sot)] na.outer_out_sit_sot)]
seen = [] seen = []
na.outer_out_mit_sot = [map_out(i, o, seen) assert len(na.outer_in_mit_sot) == len(na.inner_out_mit_sot)
for i, o in zip(na.inner_out_mit_sot, assert len(na.inner_out_mit_sot) == len(na.outer_out_mit_sot)
na.outer_out_mit_sot = [
map_out(outer_i, inner_o, outer_o, seen)
for outer_i, inner_o, outer_o in zip(na.outer_in_mit_sot,
na.inner_out_mit_sot,
na.outer_out_mit_sot)] na.outer_out_mit_sot)]
seen = [] seen = []
new_outer_out_mit_mot = [] new_outer_out_mit_mot = []
for imm, omm, osl in zip(na.inner_out_mit_mot, assert len(na.outer_in_mit_mot) == len(na.inner_out_mit_mot)
na.outer_out_mit_mot, na.mit_mot_out_slices): assert len(na.inner_out_mit_mot) == len(na.outer_out_mit_mot)
for simm, somm, sosl in seen: assert len(na.outer_out_mit_mot) == len(na.mit_mot_out_slices)
if osl == sosl and equal_computations(imm, simm, left, right): for outer_imm, inner_omm, outer_omm, osl in zip(na.outer_in_mit_mot,
new_outer_out_mit_mot.append(somm) na.inner_out_mit_mot,
na.outer_out_mit_mot,
na.mit_mot_out_slices):
for s_outer_imm, s_inner_omm, s_outer_omm, sosl in seen:
if (osl == sosl
and equal_computations(inner_omm, s_inner_omm, left, right)
and outer_imm == s_outer_imm):
new_outer_out_mit_mot.append(s_outer_omm)
break break
else: else:
seen.append((imm, omm, osl)) seen.append((outer_imm, inner_omm, outer_omm, osl))
new_outer_out_mit_mot.append(omm) new_outer_out_mit_mot.append(outer_omm)
na.outer_out_mit_mot = new_outer_out_mit_mot na.outer_out_mit_mot = new_outer_out_mit_mot
return na.outer_outputs return na.outer_outputs
......
import numpy
import unittest
import theano
from theano import config
from theano import tensor as T
from theano.tests import unittest_tools as utt
mode = theano.compile.mode.get_mode(config.mode)
class TestGaussNewton(unittest.TestCase):
    """
    Regression test for code exhibiting various optimization errors
    (notably in the "scan_merge_inouts" optimization).

    This test case is based on code by Sigurd Spieckermann.
    """
    def setUp(self):
        # Seeded RNG so failures are reproducible across runs.
        self.rng = numpy.random.RandomState(utt.fetch_seed())

    def _run(self, num_features, num_timesteps, batch_size, mode):
        """Build a small RNN + Gauss-Newton graph and compile/run it.

        batch_size == 1 uses matrix inputs; otherwise tensor3 inputs.
        The point of the test is that compilation succeeds under the
        given optimization mode, not the numerical result.
        """
        # determine shapes of inputs and targets depending on the batch size
        if batch_size == 1:
            inputs_size = (num_timesteps, num_features)
            targets_size = (num_timesteps, 1)
        else:
            inputs_size = (num_timesteps, batch_size, num_features)
            targets_size = (num_timesteps, batch_size, 1)

        # make inputs and targets shared variables
        inputs = theano.shared(
            self.rng.uniform(size=inputs_size).astype(config.floatX),
            borrow=True)
        targets = theano.shared(
            self.rng.uniform(size=targets_size).astype(config.floatX),
            borrow=True)

        # create symbolic inputs and targets variables
        if batch_size == 1:
            x = T.matrix('inputs')
            t = T.matrix('targets')
        else:
            x = T.tensor3('inputs')
            # Fixed: this variable was mistakenly named 'inputs'
            # (copy-paste of the line above); it holds the targets.
            t = T.tensor3('targets')
        x.tag.test_value = inputs.get_value(borrow=True)
        t.tag.test_value = targets.get_value(borrow=True)

        # create a set of parameters for a simple RNN
        W_xh = theano.shared(
            (0.01 * self.rng.uniform(
                size=(num_features, 10))).astype(config.floatX),
            borrow=True)
        W_hh = theano.shared(
            (0.01 * self.rng.uniform(size=(10, 10))).astype(config.floatX),
            borrow=True)
        W_hy = theano.shared(
            (0.01 * self.rng.uniform(size=(10, 1))).astype(config.floatX),
            borrow=True)
        b_h = theano.shared(numpy.zeros(10).astype(config.floatX), borrow=True)
        b_y = theano.shared(numpy.zeros(1).astype(config.floatX), borrow=True)

        params = [W_xh, W_hh, W_hy, b_h, b_y]

        # recurrent function: one tanh RNN step
        def step(x_t, h_tm1):
            h = T.tanh(T.dot(h_tm1, W_hh) + T.dot(x_t, W_xh) + b_h)
            return h

        # build recurrent graph
        if batch_size == 1:
            h_0 = T.alloc(0.0, 10).astype(config.floatX)
        else:
            h_0 = T.alloc(0.0, batch_size, 10).astype(config.floatX)
        h, updates = theano.scan(step,
                                 sequences=[x],
                                 outputs_info=[h_0])
        # network output
        y = T.dot(h, W_hy) + b_y

        # Create Gauss-Newton-Matrix object. Not really of any use here, but I
        # need it for Hessian-Free optimization.
        gn = GaussNewtonMatrix(y)

        # compute MSE
        cost = ((t - y) ** 2).sum(axis=1).mean()

        # Compute the cost at some other point in the parameter
        # space. Not really of any use here, but this is how I do it
        # during certain iterations of CG in the HF algorithm. There,
        # it's in fact `pi + current update proposal`. For simplicity,
        # I just multiply by 2 here.
        cost_ = theano.clone(cost,
                             replace=dict([(pi, 2 * pi) for pi in params]))

        # Compute Gauss-Newton-Matrix times some vector `v` which is `p` in CG,
        # but for simplicity, I just take the parameters vector because it's
        # already there.
        Gv = gn(v=params, cost=cost, parameters=params, damp=T.constant(1.0))

        # compile Theano function
        f = theano.function([], [cost_] + Gv, givens={x: inputs, t: targets},
                            mode=mode)
        # execute (compilation succeeding without an optimization error is
        # the actual assertion of this test)
        f()

    def test_batch(self):
        # This runs fine. The batch size is set to something greater than 1,
        # i.e. the data is represented by a tensor3 object.
        self._run(100, 10, batch_size=5, mode=mode)

    def test_nobatch(self):
        # This used to give an error due to optimization "scan_merge_inouts".
        # The batch size is set to 1 and the data is represented by a matrix.
        # As of 2013-10-24, it still triggers an optimization error due to
        # "remove_constants_and_unused_inputs_scan".
        mode_exc = mode.excluding("remove_constants_and_unused_inputs_scan")
        self._run(100, 10, batch_size=1, mode=mode_exc)
class GaussNewtonMatrix(object):
    """Matrix-vector product with the Gauss-Newton matrix of a network.

    The matrix is defined with respect to `s`, the linear (pre-activation)
    network output, and is never formed explicitly: only its product with
    a vector is computed, via R-op / grad tricks.
    """

    def __init__(self, s):
        # Keep the linear network output, i.e. the network output before
        # the activation function has been applied.
        self._s = s

    def __call__(self, v, cost, parameters, damp):
        """Return (G + damp * I) * v for the Gauss-Newton matrix G."""
        # Forward-mode product J * v, where J is the Jacobian of the
        # linear output wrt the parameters.
        jac_vec = T.Rop(self._s, parameters, v)
        # H * (J * v), with H the Hessian of the cost wrt the linear output.
        hess_jac_vec = T.grad(T.sum(T.grad(cost, self._s) * jac_vec),
                              self._s, consider_constant=[jac_vec])
        # J^T * (H * J * v), one term per parameter.
        gn_vec = T.grad(T.sum(hess_jac_vec * self._s), parameters,
                        consider_constant=[hess_jac_vec, jac_vec])
        # Tikhonov damping: add damp * v componentwise.
        return [gn_i + damp * v_i for gn_i, v_i in zip(gn_vec, v)]
...@@ -2579,7 +2579,7 @@ class Alloc(gof.Op): ...@@ -2579,7 +2579,7 @@ class Alloc(gof.Op):
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
if eval_points[0] is None: if eval_points[0] is None:
return [None] return [None]
return self.make_node(eval_points[0], *inputs[1:]).outputs return self(eval_points[0], *inputs[1:], **dict(return_list=True))
def do_constant_folding(self, node): def do_constant_folding(self, node):
if not getattr(node.outputs[0], 'clients', []): if not getattr(node.outputs[0], 'clients', []):
...@@ -3275,7 +3275,7 @@ class Rebroadcast(Op): ...@@ -3275,7 +3275,7 @@ class Rebroadcast(Op):
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
if eval_points[0] is None: if eval_points[0] is None:
return [None] return [None]
return self.make_node(*eval_points).outputs return self(*eval_points, **dict(return_list=True))
def addbroadcast(x, *axes): def addbroadcast(x, *axes):
...@@ -3805,7 +3805,7 @@ class Reshape(Op): ...@@ -3805,7 +3805,7 @@ class Reshape(Op):
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
if eval_points[0] is None: if eval_points[0] is None:
return [None] return [None]
return self.make_node(eval_points[0], *inputs[1:]).outputs return self(eval_points[0], *inputs[1:], **dict(return_list=True))
def infer_shape(self, node, ishapes): def infer_shape(self, node, ishapes):
# inputs[1] can contain at most one value of '-1', meaning the actual # inputs[1] can contain at most one value of '-1', meaning the actual
...@@ -4600,7 +4600,7 @@ class Dot(Op): ...@@ -4600,7 +4600,7 @@ class Dot(Op):
eval_point_values = [ev0, ev1] eval_point_values = [ev0, ev1]
for i in xrange(2): for i in xrange(2):
if eval_point_values[i] and \ if eval_point_values[i] is not None and \
input_values[i].shape != eval_point_values[i].shape: input_values[i].shape != eval_point_values[i].shape:
raise ValueError('input ' + str(i) + ' and eval_point ' + raise ValueError('input ' + str(i) + ' and eval_point ' +
str(i) + ' to Dot.R_op ' str(i) + ' to Dot.R_op '
......
...@@ -273,7 +273,7 @@ class DimShuffle(Op): ...@@ -273,7 +273,7 @@ class DimShuffle(Op):
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
if None in eval_points: if None in eval_points:
return [None] return [None]
return self.make_node(*eval_points).outputs return self(*eval_points, **dict(return_list=True))
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
input, = inp input, = inp
...@@ -616,7 +616,7 @@ class Elemwise(Op): ...@@ -616,7 +616,7 @@ class Elemwise(Op):
return self.name return self.name
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
outs = self.make_node(*inputs).outputs outs = self(*inputs, **dict(return_list=True))
rval = [None for x in outs] rval = [None for x in outs]
# For each output # For each output
for idx, out in enumerate(outs): for idx, out in enumerate(outs):
...@@ -1882,7 +1882,7 @@ class Sum(CAReduceDtype): ...@@ -1882,7 +1882,7 @@ class Sum(CAReduceDtype):
# part of self # part of self
if None in eval_points: if None in eval_points:
return [None] return [None]
return self.make_node(*eval_points).outputs return self(*eval_points, **dict(return_list=True))
def __str__(self): def __str__(self):
if self.axis is None: if self.axis is None:
......
...@@ -263,10 +263,11 @@ def inplace_elemwise_optimizer_op(OP): ...@@ -263,10 +263,11 @@ def inplace_elemwise_optimizer_op(OP):
scalar.transfer_type( scalar.transfer_type(
*[inplace_pattern.get(i, None) \ *[inplace_pattern.get(i, None) \
for i in xrange(len(node.outputs))])) for i in xrange(len(node.outputs))]))
new = OP(new_scal, inplace_pattern).make_node( new_outputs = OP(new_scal, inplace_pattern)(
*node.inputs) *node.inputs, **dict(return_list=True))
new_node = new_outputs[0].owner
for r, new_r in zip(node.outputs, new.outputs): for r, new_r in zip(node.outputs, new_outputs):
fgraph.replace(r, new_r, fgraph.replace(r, new_r,
reason="inplace_elemwise_optimizer") reason="inplace_elemwise_optimizer")
nb_change_no_validate += 1 nb_change_no_validate += 1
...@@ -284,7 +285,7 @@ def inplace_elemwise_optimizer_op(OP): ...@@ -284,7 +285,7 @@ def inplace_elemwise_optimizer_op(OP):
fgraph.revert(chk) fgraph.revert(chk)
continue continue
candidate_inputs.remove(candidate_input) candidate_inputs.remove(candidate_input)
node = new node = new_node
baseline = inplace_pattern baseline = inplace_pattern
break break
......
...@@ -883,7 +883,7 @@ class Subtensor(Op): ...@@ -883,7 +883,7 @@ class Subtensor(Op):
# (they should be defaulted to zeros_like by the global R_op) # (they should be defaulted to zeros_like by the global R_op)
if eval_points[0] is None: if eval_points[0] is None:
return [None] return [None]
return self.make_node(eval_points[0], *inputs[1:]).outputs return self(eval_points[0], *inputs[1:], **dict(return_list=True))
class SubtensorPrinter: class SubtensorPrinter:
...@@ -1414,8 +1414,8 @@ class IncSubtensor(Op): ...@@ -1414,8 +1414,8 @@ class IncSubtensor(Op):
return [None] return [None]
# Again we ignore eval points for indices because incsubtensor is # Again we ignore eval points for indices because incsubtensor is
# not differentiable wrt to those # not differentiable wrt to those
return self.make_node(eval_points[0], eval_points[1], return self(eval_points[0], eval_points[1], *inputs[2:],
*inputs[2:]).outputs **dict(return_list=True))
def connection_pattern(self, node): def connection_pattern(self, node):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论