提交 dcc47b47 authored 作者: Samira Shabanian's avatar Samira Shabanian

Add the local_sumsqr2dot optimization to opt.py

上级 8bbc630c
......@@ -3574,6 +3574,39 @@ def local_join_make_vector(node):
return [ret]
#################
# speed/memory #
#################
@register_specialize
@gof.local_optimizer([T.elemwise.Sum])
def local_sumsqr2dot(node):
    """
    Rewrite a broadcasted sum of squares as a dot product.

    This optimization detects
    ``T.sqr(W.dimshuffle('x', 0, 1) * G.dimshuffle(0, 'x', 1)).sum(axis=(1, 2))``
    and converts it to ``T.dot(T.sqr(G), T.sqr(W).sum(axis=0))``.

    The two factors of the product may appear in either order, since the
    elementwise multiplication is commutative.

    Parameters
    ----------
    node
        The Apply node being considered; only ``Sum`` nodes over
        ``axis == (1, 2)`` are rewritten.

    Returns
    -------
    list or None
        A one-element list holding the replacement variable (cast to the
        dtype of the original output if needed), or None when the pattern
        does not match and the node must be left untouched.
    """
    if not (isinstance(node.op, T.elemwise.Sum) and
            isinstance(node.op.scalar_op, theano.scalar.basic.Add) and
            node.op.axis == (1, 2)):
        return None

    summed = node.inputs[0]
    out = node.outputs[0]

    # The summed tensor must be an elementwise square ...
    if not (summed.owner and
            isinstance(summed.owner.op, T.Elemwise) and
            isinstance(summed.owner.op.scalar_op, theano.scalar.basic.Sqr)):
        return None

    # ... of an elementwise product of exactly two factors.  A Mul node
    # with another number of inputs is not the pattern and must not be
    # unpacked blindly.
    product = summed.owner.inputs[0]
    if not (product.owner and
            isinstance(product.owner.op, T.Elemwise) and
            isinstance(product.owner.op.scalar_op, theano.scalar.basic.Mul) and
            len(product.owner.inputs) == 2):
        return None

    # Identify which factor is W (new_order ('x', 0, 1)) and which is G
    # (new_order (0, 'x', 1)).  A factor without an owner (graph input or
    # constant) cannot be a DimShuffle output, so bail out instead of
    # dereferencing ``None``.
    W = None
    G = None
    for factor in product.owner.inputs:
        if not (factor.owner and
                isinstance(factor.owner.op, T.elemwise.DimShuffle)):
            return None
        order = factor.owner.op.new_order
        if order == ('x', 0, 1) and W is None:
            W = factor.owner.inputs[0]
        elif order == (0, 'x', 1) and G is None:
            G = factor.owner.inputs[0]
        else:
            return None

    # The dot-product form is only equivalent for matrices; a DimShuffle
    # may also drop broadcastable dimensions of a higher-rank input.
    if W.ndim != 2 or G.ndim != 2:
        return None

    new_out = T.dot(T.sqr(G), T.sqr(W).sum(axis=0))
    # Keep the replacement's dtype identical to the variable it replaces.
    if new_out.dtype != out.dtype:
        new_out = T.cast(new_out, dtype=out.dtype)
    return [new_out]
#################
# Exp stability #
#################
......
......@@ -6173,6 +6173,24 @@ def test_local_zero_div():
assert theano.tensor.get_scalar_constant_value(output) == 0
def test_local_sumsqr2dot():
    """
    Check that local_sumsqr2dot fires and preserves the numerical result.

    The compiled function is compared against a NumPy evaluation of the
    original (un-rewritten) broadcast expression, so the numeric check is
    independent of the rewrite being tested.
    """
    G = matrix('G')
    W = matrix('W')
    y = T.sqr(W.dimshuffle('x', 0, 1) *
              G.dimshuffle(0, 'x', 1)).sum(axis=(1, 2))
    f = function([W, G], y)

    # Seeded generator so a failure is reproducible.
    rng = numpy.random.RandomState(1234)
    w_val = rng.rand(4, 3).astype(config.floatX)
    g_val = rng.rand(5, 3).astype(config.floatX)
    f_val = f(w_val, g_val)

    # Reference: the same broadcasted sum of squares, computed by NumPy.
    # Reducing one axis at a time is equivalent to sum(axis=(1, 2)).
    broadcast_prod = w_val[None, :, :] * g_val[:, None, :]
    expected = numpy.square(broadcast_prod).sum(axis=2).sum(axis=1)
    assert numpy.allclose(f_val, expected)

    # The rewritten graph is expected to contain a Dot node.
    # NOTE(review): if another rewrite replaces Dot with a BLAS op in the
    # default mode, this check would need to accept those ops too - confirm.
    assert any(isinstance(n.op, theano.tensor.basic.Dot)
               for n in f.maker.fgraph.toposort())
def test_local_expm1():
x = matrix('x')
u = T.scalar('u')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论