提交 780d7a43 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

fix push dot out optimization

上级 dafaffaa
...@@ -1463,10 +1463,10 @@ class PushOutDot1(gof.Optimizer): ...@@ -1463,10 +1463,10 @@ class PushOutDot1(gof.Optimizer):
inp2 = x.owner.inputs[1] inp2 = x.owner.inputs[1]
if inp1 in seqs or inp2 in seqs: if inp1 in seqs or inp2 in seqs:
new_scan_out = inp2 new_scan_out = inp1
if inp2 in seqs: if inp1 in seqs:
new_scan_out = inp1 new_scan_out = inp2
idx = sitsot_outs.index(out) idx = sitsot_outs.index(out)
# We've found our pattern and need to construct a new # We've found our pattern and need to construct a new
# scan node to replace this one. For this we need to # scan node to replace this one. For this we need to
...@@ -1535,6 +1535,8 @@ class PushOutDot1(gof.Optimizer): ...@@ -1535,6 +1535,8 @@ class PushOutDot1(gof.Optimizer):
outer_non_seqs) outer_non_seqs)
new_outs = new_op(*_scan_inputs) new_outs = new_op(*_scan_inputs)
if type(new_outs) not in (list, tuple):
new_outs = [new_outs]
# We need now to pair correctly the new outputs with the # We need now to pair correctly the new outputs with the
# old ones # old ones
......
...@@ -3322,6 +3322,20 @@ class T_Scan(unittest.TestCase): ...@@ -3322,6 +3322,20 @@ class T_Scan(unittest.TestCase):
# scan could not detect the connection between `m2` and `x` # scan could not detect the connection between `m2` and `x`
tensor.grad(m2.sum(), m) tensor.grad(m2.sum(), m)
def test_dot_optimization(self):
A = tensor.matrix('A')
B = tensor.matrix('B')
S, _ = theano.scan(lambda x1,x2, u: u + tensor.dot(x1,x2),
sequences = [A.dimshuffle(0, 1, 'x'),
B.dimshuffle(0,'x', 1)],
outputs_info=[tensor.zeros_like(A)])
f = theano.function([A,B], S.owner.inputs[0][-1])
rng = numpy.random.RandomState(utt.fetch_seed())
vA = rng.uniform(size=(5,5))
vB = rng.uniform(size=(5,5))
assert numpy.allclose(f(vA, vB), numpy.dot(vA.T, vB))
def test_pregreedy_optimizer(self): def test_pregreedy_optimizer(self):
W = tensor.zeros((5, 4)) W = tensor.zeros((5, 4))
bv = tensor.zeros((5,)) bv = tensor.zeros((5,))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论