Commit 7927f4a0, authored by Frédéric Bastien

Merge pull request #4263 from shabanian/opt

Add the sumsqr2dot optimization to opt.py
...@@ -3539,6 +3539,39 @@ def local_join_make_vector(node): ...@@ -3539,6 +3539,39 @@ def local_join_make_vector(node):
return [ret] return [ret]
#################
# speed/memory #
#################
@register_canonicalize
@register_specialize
@gof.local_optimizer([T.elemwise.Sum])
def local_sumsqr2dot(node):
    """
    Rewrite a sum of squared broadcast products as a dot product.

    This optimization detects
    ``T.sqr(W.dimshuffle('x', 0, 1) * G.dimshuffle(0, 'x', 1)).sum(axis=(1, 2))``
    and converts it to ``T.dot(T.sqr(G), T.sqr(W).sum(axis=0))``.
    The two factors of the product are accepted in either order, since the
    elementwise multiplication is commutative.

    Parameters
    ----------
    node
        The apply node under inspection; only its ``op``, ``inputs`` and
        ``outputs`` are read.

    Returns
    -------
    list or None
        A one-element list holding the replacement output (cast to the dtype
        of the original output when the dtypes differ), or None when the
        pattern does not match.

    """
    # Outermost op: a Sum (scalar Add reduction) over axes (1, 2).
    if not (isinstance(node.op, T.elemwise.Sum) and
            isinstance(node.op.scalar_op, theano.scalar.basic.Add) and
            node.op.axis == (1, 2)):
        return None

    in1 = node.inputs[0]
    out = node.outputs[0]

    # The summed variable must be an elementwise square.
    if not (in1.owner and isinstance(in1.owner.op, T.Elemwise) and
            isinstance(in1.owner.op.scalar_op, theano.scalar.basic.Sqr)):
        return None

    # The squared variable must be an elementwise product of two factors.
    in_sqr = in1.owner.inputs[0]
    if not (in_sqr.owner and isinstance(in_sqr.owner.op, T.Elemwise) and
            isinstance(in_sqr.owner.op.scalar_op, theano.scalar.basic.Mul) and
            len(in_sqr.owner.inputs) == 2):
        return None

    # Exactly one factor must be a ('x', 0, 1) DimShuffle (giving W) and the
    # other a (0, 'x', 1) DimShuffle (giving G).
    W = None
    G = None
    for factor in in_sqr.owner.inputs:
        # A factor without an owner (not produced by any op) cannot be a
        # DimShuffle output; decline instead of dereferencing ``owner``.
        if not (factor.owner and
                isinstance(factor.owner.op, T.elemwise.DimShuffle)):
            return None
        new_order = factor.owner.op.new_order
        if W is None and new_order == ('x', 0, 1):
            W = factor.owner.inputs[0]
        elif G is None and new_order == (0, 'x', 1):
            G = factor.owner.inputs[0]
        else:
            return None

    new_out = T.dot(T.sqr(G), T.sqr(W).sum(axis=0))
    # Keep the replacement's dtype identical to the replaced output's dtype.
    if new_out.dtype != out.dtype:
        new_out = T.cast(new_out, dtype=out.dtype)
    return [new_out]
#################
# Exp stability #
#################
......
...@@ -6173,6 +6173,27 @@ def test_local_zero_div(): ...@@ -6173,6 +6173,27 @@ def test_local_zero_div():
assert theano.tensor.get_scalar_constant_value(output) == 0 assert theano.tensor.get_scalar_constant_value(output) == 0
def test_local_sumsqr2dot():
    """
    Check that the 'local_sumsqr2dot' rewrite keeps the numerical result and
    leaves a dot-like op in the compiled graph.
    """
    G = matrix('G')
    W = matrix('W')

    # Symbolic expression in the exact form the optimization looks for.
    broadcast_product = W.dimshuffle('x', 0, 1) * G.dimshuffle(0, 'x', 1)
    y = T.sqr(broadcast_product).sum(axis=(1, 2))

    mode_with_opt = theano.compile.get_default_mode().including(
        'local_sumsqr2dot')
    f = function([W, G], y, mode=mode_with_opt)

    w_val = numpy.random.rand(4, 3).astype(config.floatX)
    g_val = numpy.random.rand(5, 3).astype(config.floatX)

    # Reference value computed with numpy using the rewritten formula.
    expected = numpy.dot(numpy.square(g_val),
                         numpy.square(w_val).sum(axis=0))
    assert numpy.allclose(f(w_val, g_val), expected)

    # At least one node of the compiled graph must be a dot-like op.
    dot_like_ops = (tensor.basic.Dot, tensor.blas.Dot22,
                    tensor.blas.Gemv, tensor.blas_c.CGemv)
    compiled_nodes = f.maker.fgraph.toposort()
    assert any(isinstance(n.op, dot_like_ops) for n in compiled_nodes)
def test_local_expm1(): def test_local_expm1():
x = matrix('x') x = matrix('x')
u = T.scalar('u') u = T.scalar('u')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论