提交 d12e8b81 authored 作者: James Bergstra's avatar James Bergstra

Ger optimizations.

They are all working except the outer one because outer doesn't turn into a GEMM yet.
上级 1d127d63
...@@ -124,6 +124,7 @@ from theano.gof.python25 import all, any
import theano.scalar
import basic as T
from theano.tensor.blas_headers import blas_header_text #, cblas_header_text
from theano.tensor.opt import local_dimshuffle_lift
_logger = logging.getLogger('theano.tensor.blas')
...@@ -1333,6 +1334,35 @@ def local_gemm_to_gemv(node):
        r = gemv_no_inplace(z.dimshuffle(0), a, x, y.dimshuffle(0), b)
        return [r.dimshuffle(0, 'x')]
@local_optimizer([gemm_no_inplace])
def local_gemm_to_ger(node):
    """GEMM computing an outer-product -> GER
    """
    if node.op != gemm_no_inplace:
        return
    z, alpha, left, right, beta = node.inputs
    # An outer product is a (M, 1) x (1, N) GEMM: the inner dimension of
    # both operands must be broadcastable.
    if not (left.broadcastable[1] and right.broadcastable[0]):
        return
    # Drop the length-1 axes to obtain the two vectors GER expects.
    col = left.dimshuffle(0)
    row = right.dimshuffle(1)
    try:
        beta_val = T.get_constant_value(beta)
    except TypeError:
        # beta isn't a constant, so GEMM's pre-scaling of z is useful work
        # that GER cannot express -- leave the graph alone.
        return
    if beta_val == 1:
        # z + alpha * outer(col, row): exactly what GER computes.
        return [Ger(destructive=False)(z, alpha, col, row)]
    elif beta_val == 0:
        # z is discarded; GER on a freshly-allocated zero matrix should
        # be faster than GEMM.
        zero_mat = T.alloc(
                numpy.asarray(0, dtype=left.dtype),
                left.shape[0], right.shape[1])
        return [Ger(destructive=False)(zero_mat, alpha, col, row)]
    else:
        # Any other constant beta means z is being usefully pre-scaled,
        # and GER isn't really the right tool for the job.
        return
#################################
#
...@@ -1354,9 +1384,12 @@ blas_optdb.register('local_dot_to_gemm',
                    GemmOptimizer(),
                    10, 'fast_run')
blas_optdb.register('local_gemm_to_gemv',
EquilibriumOptimizer([local_gemm_to_gemv], max_use_ratio=5), EquilibriumOptimizer([local_gemm_to_gemv, local_gemm_to_ger,
local_dimshuffle_lift],
max_use_ratio=5),
15, 'fast_run') 15, 'fast_run')
# After destroyhandler is in but before we try to make elemwise things inplace
# Try to make gemm inplace
# Also, need to make the gemm optimisation(step 70) happen before the fusion of elemwise(step 71)
......
#from nose.plugins.skip import SkipTest
#import traceback
import sys
import theano.tensor as T
#from theano.gof import Env
from theano.printing import pp
...@@ -1263,3 +1264,71 @@ class TestGer_make_thunk(TestCase):
    def test_c128_1_9(s): return s.given_dtype('complex128', 1, 9)
# TODO: Refactor and add to this base class as we refactor test code.
class TestOptimizationMixin(object):
    """Mixin providing assertions about the ops present in a compiled
    theano function's graph."""

    def assertFunctionContains(self, f, op, min=1, max=None):
        """Assert `f`'s optimized graph contains between `min` and `max`
        (inclusive; `None` means unbounded) apply nodes whose op == `op`.

        The toposort is included in the failure message for debugging.
        """
        toposort = f.maker.env.toposort()
        matches = [node for node in toposort if node.op == op]
        assert min <= len(matches), toposort
        # max=None replaces the old `sys.maxint` default: same unbounded
        # behavior, but portable across Python versions.
        assert max is None or len(matches) <= max, toposort

    def assertFunctionContains0(self, f, op):
        # BUG FIX: these helpers previously called the bare name
        # `assertFunctionContains`, raising NameError when used.
        return self.assertFunctionContains(f, op, min=0, max=0)

    def assertFunctionContains1(self, f, op):
        return self.assertFunctionContains(f, op, min=1, max=1)

    def assertFunctionContainsN(self, f, op, N):
        return self.assertFunctionContains(f, op, min=N, max=N)
class TestGer_local_gemm_to_ger(TestCase, TestOptimizationMixin):
    # Tests for the local_gemm_to_ger optimization: GEMM graphs computing
    # an outer product should be rewritten into GER ops.
    # NOTE(review): `gemm_no_inplace`, `ger` and `ger_destructive` are
    # presumably imported from theano.tensor.blas elsewhere in this test
    # file -- confirm against the module's import block.

    def setUp(self):
        # Full optimization pipeline, so the gemm -> ger rewrite runs.
        self.mode = theano.Mode(optimizer='fast_run')
        dtype = self.dtype = 'float64'  # optimization isn't dtype-dependent
        # A: matrix, a: scalar, x and y: vectors.
        self.A = T.tensor(dtype=dtype, broadcastable=(False, False))
        self.a = T.tensor(dtype=dtype, broadcastable=())
        self.x = T.tensor(dtype=dtype, broadcastable=(False,))
        self.y = T.tensor(dtype=dtype, broadcastable=(False,))

    def function(self, inputs, outputs):
        # Compile with the test's optimizing mode.
        return theano.function(inputs, outputs, self.mode)

    def b(self, bval):
        # Wrap a scalar beta value as a constant tensor of the test dtype.
        return T.as_tensor_variable(numpy.asarray(bval, dtype=self.dtype))

    def test_b_0_triggers_ger(self):
        # beta == 0: the optimizer should offer a replacement (GER on a
        # zero matrix).  dimshuffle turns x into a column and y into a row.
        assert T.blas.local_gemm_to_ger.transform(
                gemm_no_inplace(
                    self.A, self.a, self.x.dimshuffle(0, 'x'),
                    self.y.dimshuffle('x', 0), self.b(0)).owner)

    def test_b_1_triggers_ger(self):
        # beta == 1: a natural GER, so a replacement should be offered.
        assert T.blas.local_gemm_to_ger.transform(
                gemm_no_inplace(
                    self.A, self.a, self.x.dimshuffle(0, 'x'),
                    self.y.dimshuffle('x', 0), self.b(1)).owner)

    def test_b_other_does_not_triggers_ger(self):
        # Any other constant beta means z is usefully pre-scaled by GEMM;
        # the optimizer must decline.
        assert not T.blas.local_gemm_to_ger.transform(
                gemm_no_inplace(
                    self.A, self.a, self.x.dimshuffle(0, 'x'),
                    self.y.dimshuffle('x', 0), self.b(1.5)).owner)

    def test_outer(self):
        # NOTE(review): the commit message says outer() does not yet turn
        # into a GEMM, so this expectation may not hold yet -- confirm.
        f = self.function([self.x, self.y], T.outer(self.x, self.y))
        self.assertFunctionContains(f, ger_destructive)

    def test_A_plus_outer(self):
        # A + outer(x, y) is GEMM with beta == 1 -> non-destructive GER.
        f = self.function([self.A, self.x, self.y],
                self.A + T.outer(self.x, self.y))
        self.assertFunctionContains(f, ger)

    def test_A_plus_scaled_outer(self):
        # Scaling the outer product only changes alpha; still a GER.
        f = self.function([self.A, self.x, self.y],
                self.A + 0.1 * T.outer(self.x, self.y))
        self.assertFunctionContains(f, ger)

    def test_scaled_A_plus_scaled_outer(self):
        # Scaling A gives beta == 0.2, which GER cannot express, so the
        # graph must keep a GEMM.
        f = self.function([self.A, self.x, self.y],
                0.2 * self.A + 0.1 * T.outer(self.x, self.y))
        self.assertFunctionContains(f, gemm_no_inplace)
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论