Commit 2e5fcea2 authored by lamblin

Merge pull request #1054 from pascanur/new_optimizations_scan

New optimizations scan
......@@ -1375,18 +1375,18 @@ class Scan(PureOp):
def compute_gradient(y, g_y):
    """Return the gradients of cost `y` w.r.t. the differentiable
    inputs, seeded with the known output gradient `g_y`.

    The scrape of this diff had duplicated old/new lines interleaved;
    this is the reconstructed post-merge version.

    :param y: graph output whose gradient is being propagated.
    :param g_y: gradient flowing into `y`; must be a float type
        (integer gradients are rejected).
    :return: list aligned with ``diff_inputs`` (closure variable from
        the enclosing scope); entries are ``None`` for inputs that `y`
        does not depend on.
    """
    # Gradients are only meaningful for continuous (float) types.
    if 'int' in str(g_y.dtype):
        raise TypeError("Gradients may never be integers but g_y "
                        "has type " + str(g_y.type))
    # Restrict differentiation to inputs of `y` that we care about.
    wrt = [x for x in theano.gof.graph.inputs([y])
           if x in diff_inputs]
    # consider_constant=wrt stops grad from recursing past these
    # inputs; disconnected inputs yield None rather than an error.
    grads = gradient.grad(
        cost=None,
        known_grads={y: g_y},
        wrt=wrt, consider_constant=wrt,
        disconnected_inputs='ignore',
        return_disconnected='None')
    gmp = dict(zip(wrt, grads))
    # Align results with diff_inputs; missing entries become None.
    rval = [gmp.get(p, None) for p in diff_inputs]
    return rval
dC_dinps_t = [None for inp in diff_inputs]
disconnected_dC_dinps_t = [True for inp in diff_inputs]
......@@ -1727,7 +1727,7 @@ class Scan(PureOp):
node = outs[0].owner
for idx in xrange(self.n_shared_outs):
disconnected = True
connected_flags = self.connection_pattern(node)[idx+start]
connected_flags = self.connection_pattern(node)[idx + start]
for dC_dout, connected in zip(dC_douts, connected_flags):
if (not isinstance(dC_dout.type, DisconnectedType) and
connected):
......
......@@ -191,8 +191,6 @@ def get_updates_and_outputs(ls):
this function know how to put it in that order?
"""
def is_outputs(elem):
if (isinstance(elem, (list, tuple)) and
all([isinstance(x, theano.Variable) for x in elem])):
......@@ -206,7 +204,7 @@ def get_updates_and_outputs(ls):
# Make sure the updates will be applied in a deterministic order
if not isinstance(elem, gof.python25.OrderedDict):
warnings.warn("Expected OrderedDict or OrderedUpdates, got "\
+str(type(elem))+". This can make your script non-"
+ str(type(elem)) + ". This can make your script non-"
"deterministic.")
return True
# Dictionaries can be given as lists of tuples
......@@ -253,7 +251,6 @@ def get_updates_and_outputs(ls):
'values, you can use `tensor.constant` to turn them into '
'Theano variables.')
if is_outputs(ls):
return None, _list(ls), OrderedDict()
if is_updates(ls):
......@@ -389,7 +386,7 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
elif (isinstance(dx, tensor.Constant) and
isinstance(dy, tensor.Constant)):
if not (numpy.all(dx.data == dy.data) and
dx.dtype == dy.dtype and
dx.type.dtype == dy.type.dtype and
dx.data.shape == dy.data.shape):
return False
else:
......@@ -413,7 +410,7 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
if (isinstance(dx, tensor.Constant) and
isinstance(dy, tensor.Constant)):
if not (numpy.all(dx.data == dy.data) and
dx.dtype == dy.dtype and
dx.type.dtype == dy.type.dtype and
dx.data.shape == dy.data.shape):
return False
else:
......
......@@ -2346,9 +2346,7 @@ class T_Scan(unittest.TestCase):
# this new assert is here to test if scan_merging works ..
nb_scan = len([n for n in topo
if isinstance(n.op, theano.scan_module.scan_op.Scan)])
# For this to work we need an optimization that it will be pushed in
# a new pull request
self.assertTrue(nb_scan == 2)
self.assertTrue(nb_scan == 1)
nb_shape_i = len([n for n in topo
if isinstance(n.op, theano.tensor.opt.Shape_i)])
if theano.config.mode != 'FAST_COMPILE':
......@@ -2364,7 +2362,8 @@ class T_Scan(unittest.TestCase):
sx, upx = theano.scan(sum, sequences=[x])
sy, upy = theano.scan(sum, sequences=[y])
f = theano.function([x, y], [sx, sy], mode=mode_with_opt)
f = theano.function([x, y], [sx, sy],
mode=mode_with_opt.excluding('scanOp_pushout_seqs_ops'))
topo = f.maker.fgraph.toposort()
scans = filter(lambda n: isinstance(
n.op, theano.scan_module.scan_op.Scan), topo)
......@@ -2373,7 +2372,8 @@ class T_Scan(unittest.TestCase):
sx, upx = theano.scan(sum, sequences=[x], n_steps=2)
sy, upy = theano.scan(sum, sequences=[y], n_steps=3)
f = theano.function([x, y], [sx, sy], mode=mode_with_opt)
f = theano.function([x, y], [sx, sy],
mode=mode_with_opt.excluding('scanOp_pushout_seqs_ops'))
topo = f.maker.fgraph.toposort()
scans = filter(lambda n: isinstance(
n.op, theano.scan_module.scan_op.Scan), topo)
......@@ -2382,7 +2382,8 @@ class T_Scan(unittest.TestCase):
sx, upx = theano.scan(sum, sequences=[x], n_steps=4)
sy, upy = theano.scan(sum, sequences=[y], n_steps=4)
f = theano.function([x, y], [sx, sy], mode=mode_with_opt)
f = theano.function([x, y], [sx, sy],
mode=mode_with_opt.excluding('scanOp_pushout_seqs_ops'))
topo = f.maker.fgraph.toposort()
scans = filter(lambda n: isinstance(
n.op, theano.scan_module.scan_op.Scan), topo)
......@@ -2391,7 +2392,8 @@ class T_Scan(unittest.TestCase):
sx, upx = theano.scan(sum, sequences=[x])
sy, upy = theano.scan(sum, sequences=[x])
f = theano.function([x], [sx, sy], mode=mode_with_opt)
f = theano.function([x], [sx, sy],
mode=mode_with_opt.excluding('scanOp_pushout_seqs_ops'))
topo = f.maker.fgraph.toposort()
scans = filter(lambda n:
isinstance(n.op, theano.scan_module.scan_op.Scan), topo)
......@@ -2401,7 +2403,7 @@ class T_Scan(unittest.TestCase):
sy, upy = theano.scan(sum, sequences=[x], mode='FAST_COMPILE')
f = theano.function([x], [sx, sy],
mode=mode_with_opt)
mode=mode_with_opt.excluding('scanOp_pushout_seqs_ops'))
topo = f.maker.fgraph.toposort()
scans = filter(lambda n:
isinstance(n.op, theano.scan_module.scan_op.Scan), topo)
......@@ -2410,7 +2412,8 @@ class T_Scan(unittest.TestCase):
sx, upx = theano.scan(sum, sequences=[x])
sy, upy = theano.scan(sum, sequences=[x], truncate_gradient=1)
f = theano.function([x], [sx, sy], mode=mode_with_opt)
f = theano.function([x], [sx, sy],
mode=mode_with_opt.excluding('scanOp_pushout_seqs_ops'))
topo = f.maker.fgraph.toposort()
scans = filter(lambda n:
isinstance(n.op, theano.scan_module.scan_op.Scan), topo)
......@@ -2820,12 +2823,12 @@ class T_Scan(unittest.TestCase):
vx = numpy.zeros((50,), dtype=theano.config.floatX)
vx[23] = 4
out, out2 = f(vx)
print 'len_out', len(out)
assert len(out) == 24
assert numpy.all(out2 == vx + 2)
lssc = [x for x in f.maker.fgraph.toposort()
if isinstance(x.op, theano.scan_module.scan_op.Scan)]
assert len(lssc) == 2
# One scan node gets optimnized out
assert len(lssc) == 1
@dec.knownfailureif(True,
("This test fails because not typed outputs_info "
......@@ -3303,6 +3306,70 @@ class T_Scan(unittest.TestCase):
theano.scan_module.scan_op.Scan)]
assert len(scan_nodes) == 1
def test_eliminate_seqs(self):
    """Scan over a sequence: every output kind (mitsot, sitsot,
    nitsot) and the shared-variable update must see the current
    sequence element."""
    seq = tensor.vector('U')
    shared_state = theano.shared(asarrayX(2.))
    mit_init = tensor.vector('x1')
    sit_init = tensor.scalar('x2')

    def step(*args):
        u_t = args[0]
        outs = (u_t + 1,   # mitsot
                u_t + 2,   # sitsot
                u_t + 3)   # nitsot
        return [outs, {shared_state: u_t + 4}]  # shared update

    [X1, X2, X3], updates = theano.scan(
        step,
        seq,
        [dict(initial=mit_init, taps=[-1, -3]), sit_init, None],
        n_steps=None,
        truncate_gradient=-1,
        go_backwards=False)

    fn = theano.function([seq, mit_init, sit_init], [X1, X2, X3],
                         updates=updates,
                         mode=theano.Mode(linker='py'),
                         allow_input_downcast=True)

    rng = numpy.random.RandomState(utt.fetch_seed())
    v_u = asarrayX(rng.uniform(size=(5,)))
    results = fn(v_u, [0, 0, 0], 0)
    assert numpy.allclose(results[0], v_u + 1)
    assert numpy.allclose(results[1], v_u + 2)
    assert numpy.allclose(results[2], v_u + 3)
    # Shared update runs each step; final value reflects the last element.
    assert numpy.allclose(shared_state.get_value(), v_u[-1] + 4)
def test_eliminate_nonseqs(self):
    """Scan with only a non-sequence input: every output kind (mitsot,
    sitsot, nitsot) and the shared-variable update must see the
    non-sequence value on each of the fixed n_steps iterations."""
    nonseq = tensor.scalar('W')
    shared_state = theano.shared(asarrayX(2.))
    mit_init = tensor.vector('x1')
    sit_init = tensor.scalar('x2')

    def step(*args):
        w = args[-1]
        outs = (w + 1.,   # mitsot
                w + 2.,   # sitsot
                w + 3.)   # nitsot
        return [outs, {shared_state: w + 4.}]  # shared update

    [X1, X2, X3], updates = theano.scan(
        step,
        [],
        [dict(initial=mit_init, taps=[-1, -3]), sit_init, None],
        nonseq,
        n_steps=5,
        truncate_gradient=-1,
        go_backwards=False)

    fn = theano.function([nonseq, mit_init, sit_init], [X1, X2, X3],
                         updates=updates,
                         mode=theano.Mode(linker='py'),
                         allow_input_downcast=True)

    rng = numpy.random.RandomState(utt.fetch_seed())
    v_w = asarrayX(rng.uniform())
    results = fn(v_w, [0, 0, 0], 0)
    assert numpy.allclose(results[0], v_w + 1)
    assert numpy.allclose(results[1], v_w + 2)
    assert numpy.allclose(results[2], v_w + 3)
    # Every step writes the same update, so the final value is w + 4.
    assert numpy.allclose(shared_state.get_value(), v_w + 4)
def test_speed():
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Please finish editing this comment first!
Register or sign in to comment