提交 8cc83780 authored 作者: delallea's avatar delallea

Merge pull request #79 from dwf/blas_pep8

STY: various blas/test_blas fixes
...@@ -1206,13 +1206,16 @@ class GemmOptimizer(Optimizer): ...@@ -1206,13 +1206,16 @@ class GemmOptimizer(Optimizer):
try: try:
env.replace_all_validate( env.replace_all_validate(
zip(node.outputs, new_outputs), zip(node.outputs, new_outputs),
reason = 'GemmOptimizer') reason='GemmOptimizer'
)
did_something = True did_something = True
break break
except InconsistencyError, e: except InconsistencyError, e:
#TODO: retry other applications of gemm (see comment in _gemm_from_node # TODO: retry other applications of gemm (see comment
# in _gemm_from_node)
pass pass
class Dot22(GemmRelated): class Dot22(GemmRelated):
"""Compute a matrix-matrix product. """Compute a matrix-matrix product.
This is a specialization of the more general Dot() This is a specialization of the more general Dot()
...@@ -1227,7 +1230,7 @@ class Dot22(GemmRelated): ...@@ -1227,7 +1230,7 @@ class Dot22(GemmRelated):
raise TypeError('dtype mismatch to Dot22') raise TypeError('dtype mismatch to Dot22')
bz = (x.type.broadcastable[0], y.type.broadcastable[1]) bz = (x.type.broadcastable[0], y.type.broadcastable[1])
outputs = [T.tensor(x.type.dtype, bz)] outputs = [T.tensor(x.type.dtype, bz)]
return Apply(self, [x,y], outputs) return Apply(self, [x, y], outputs)
def perform(self, node, inp, out): def perform(self, node, inp, out):
x, y = inp x, y = inp
...@@ -1235,9 +1238,11 @@ class Dot22(GemmRelated): ...@@ -1235,9 +1238,11 @@ class Dot22(GemmRelated):
try: try:
z[0] = numpy.asarray(numpy.dot(x, y)) z[0] = numpy.asarray(numpy.dot(x, y))
except ValueError, e: except ValueError, e:
# The error raised by numpy has no shape information, we mean to add that # The error raised by numpy has no shape information, we mean to
# add that
e.args = e.args + (x.shape, y.shape) e.args = e.args + (x.shape, y.shape)
raise raise
def __str__(self): def __str__(self):
return "_dot22" return "_dot22"
...@@ -1250,10 +1255,12 @@ class Dot22(GemmRelated): ...@@ -1250,10 +1255,12 @@ class Dot22(GemmRelated):
npy_intp dims[2]; npy_intp dims[2];
dims[0] = %(_x)s->dimensions[0]; dims[0] = %(_x)s->dimensions[0];
dims[1] = %(_y)s->dimensions[1]; dims[1] = %(_y)s->dimensions[1];
%(_zout)s = (PyArrayObject*)PyArray_SimpleNew(2, dims, type_num_%(_x)s); %(_zout)s = (PyArrayObject*)PyArray_SimpleNew(2, dims,
type_num_%(_x)s);
//fprintf(stderr, "Dot Allocating %%i %%i\\n", dims[0], dims[1]); //fprintf(stderr, "Dot Allocating %%i %%i\\n", dims[0], dims[1]);
if(!%(_zout)s) { if(!%(_zout)s) {
PyErr_SetString(PyExc_MemoryError, "failed to alloc dot22 output"); PyErr_SetString(PyExc_MemoryError,
"failed to alloc dot22 output");
%(fail)s %(fail)s
} }
} }
...@@ -1270,16 +1277,19 @@ class Dot22(GemmRelated): ...@@ -1270,16 +1277,19 @@ class Dot22(GemmRelated):
double a = 1.0; double a = 1.0;
double b = 0.0; double b = 0.0;
""" """
def c_code(self, node, name, inp, out, sub): #DEBUG
def c_code(self, node, name, inp, out, sub): # DEBUG
_x, _y = inp _x, _y = inp
_zout, = out _zout, = out
if node.inputs[0].type.dtype.startswith('complex'): if node.inputs[0].type.dtype.startswith('complex'):
raise utils.MethodNotDefined('%s.c_code' \ raise utils.MethodNotDefined('%s.c_code' \
% self.__class__.__name__) % self.__class__.__name__)
if len(self.c_libraries())<=0: if len(self.c_libraries()) <= 0:
return super(Dot22, self).c_code(node, name, (_x, _y), (_zout, ), sub) return super(Dot22, self).c_code(node, name, (_x, _y),
(_zout, ), sub)
full_code = self.build_gemm_call() % dict(locals(), **sub) full_code = self.build_gemm_call() % dict(locals(), **sub)
return full_code return full_code
def c_code_cache_version(self): def c_code_cache_version(self):
gv = self.build_gemm_version() gv = self.build_gemm_version()
if gv: if gv:
...@@ -1436,6 +1446,7 @@ optdb.register('InplaceBlasOpt', ...@@ -1436,6 +1446,7 @@ optdb.register('InplaceBlasOpt',
blas_opt_inplace, blas_opt_inplace,
70.0, 'fast_run', 'inplace') 70.0, 'fast_run', 'inplace')
class Dot22Scalar(GemmRelated): class Dot22Scalar(GemmRelated):
"""Compute a matrix-matrix product. """Compute a matrix-matrix product.
This is a specialization of the more general Dot() This is a specialization of the more general Dot()
...@@ -1473,6 +1484,7 @@ class Dot22Scalar(GemmRelated): ...@@ -1473,6 +1484,7 @@ class Dot22Scalar(GemmRelated):
# The error raised by numpy has no shape information, we mean to add that # The error raised by numpy has no shape information, we mean to add that
e.args = e.args + (x.shape, y.shape) e.args = e.args + (x.shape, y.shape)
raise raise
def __str__(self): def __str__(self):
return "_dot22scalar" return "_dot22scalar"
...@@ -1492,6 +1504,7 @@ class Dot22Scalar(GemmRelated): ...@@ -1492,6 +1504,7 @@ class Dot22Scalar(GemmRelated):
#undef REAL #undef REAL
float b = 0.0; float b = 0.0;
""" """
case_double_ab_constants = """ case_double_ab_constants = """
#define REAL double #define REAL double
double a = (%(_a)s->descr->type_num == PyArray_FLOAT) double a = (%(_a)s->descr->type_num == PyArray_FLOAT)
...@@ -1500,16 +1513,18 @@ class Dot22Scalar(GemmRelated): ...@@ -1500,16 +1513,18 @@ class Dot22Scalar(GemmRelated):
#undef REAL #undef REAL
double b = 0.0; double b = 0.0;
""" """
def c_code(self, node, name, inp, out, sub): #DEBUG def c_code(self, node, name, inp, out, sub): #DEBUG
_x, _y, _a = inp _x, _y, _a = inp
_zout, = out _zout, = out
if node.inputs[0].type.dtype.startswith('complex'): if node.inputs[0].type.dtype.startswith('complex'):
raise utils.MethodNotDefined('%s.c_code' \ raise utils.MethodNotDefined('%s.c_code' \
% self.__class__.__name__) % self.__class__.__name__)
if len(self.c_libraries())<=0: if len(self.c_libraries()) <= 0:
return super(Dot22Scalar, self).c_code(node, name, (_x, _y), (_zout, ), sub) return super(Dot22Scalar, self).c_code(node, name, (_x, _y), (_zout, ), sub)
full_code = self.build_gemm_call() % dict(locals(), **sub) full_code = self.build_gemm_call() % dict(locals(), **sub)
return full_code return full_code
def c_code_cache_version(self): def c_code_cache_version(self):
gv = self.build_gemm_version() gv = self.build_gemm_version()
if gv: if gv:
...@@ -1558,10 +1573,10 @@ def local_dot22_to_dot22scalar(node): ...@@ -1558,10 +1573,10 @@ def local_dot22_to_dot22scalar(node):
for i,x in enumerate(m.owner.inputs): for i,x in enumerate(m.owner.inputs):
if _as_scalar(x) and (theano.scalar.upcast(x.type.dtype,d.type.dtype) if _as_scalar(x) and (theano.scalar.upcast(x.type.dtype,d.type.dtype)
== d.type.dtype): == d.type.dtype):
scalar_idx=i scalar_idx = i
break break
if scalar_idx<0: if scalar_idx < 0:
_logger.info('Not optimizing dot22 with inputs %s %s, as the type ' _logger.info('Not optimizing dot22 with inputs %s %s, as the type '
'of the scalar cannot be upcasted to the matrix type', 'of the scalar cannot be upcasted to the matrix type',
node.inputs, [x.type for x in node.inputs]) node.inputs, [x.type for x in node.inputs])
...@@ -1593,7 +1608,7 @@ def local_dot22_to_dot22scalar(node): ...@@ -1593,7 +1608,7 @@ def local_dot22_to_dot22scalar(node):
== d.type.dtype)): == d.type.dtype)):
scalar_idx = i scalar_idx = i
break break
if scalar_idx<0: if scalar_idx < 0:
_logger.info('Not optimizing dot22 with inputs %s %s, as the type ' _logger.info('Not optimizing dot22 with inputs %s %s, as the type '
'of the scalar cannot be upcasted to the matrix type', 'of the scalar cannot be upcasted to the matrix type',
node.inputs, [x.type for x in node.inputs]) node.inputs, [x.type for x in node.inputs])
......
...@@ -5,7 +5,8 @@ import theano.tensor as T ...@@ -5,7 +5,8 @@ import theano.tensor as T
#from theano.gof import Env #from theano.gof import Env
from theano.printing import pp from theano.printing import pp
import numpy, theano import numpy
import theano
from numpy import (arange, array, common_type, complex64, complex128, float32, from numpy import (arange, array, common_type, complex64, complex128, float32,
float64, newaxis, shape, transpose, zeros) float64, newaxis, shape, transpose, zeros)
from numpy.testing import assert_, assert_array_almost_equal from numpy.testing import assert_, assert_array_almost_equal
...@@ -13,10 +14,11 @@ from numpy.testing import assert_, assert_array_almost_equal ...@@ -13,10 +14,11 @@ from numpy.testing import assert_, assert_array_almost_equal
#from numpy.testing.noseclasses import KnownFailureTest #from numpy.testing.noseclasses import KnownFailureTest
#from theano.tensor.blas import * #from theano.tensor.blas import *
from theano.tensor.blas import (_dot22, _dot22scalar, res_is_a, _as_scalar, _is_real_matrix, from theano.tensor.blas import (_dot22, _dot22scalar, res_is_a, _as_scalar,
_gemm_canonicalize, _factor_canonicalized, Gemm, Gemv, gemm_inplace, gemm_no_inplace, _is_real_matrix, _gemm_canonicalize,
InconsistencyError, _factor_canonicalized, Gemm, Gemv,
Ger, ger, ger_destructive) gemm_inplace, gemm_no_inplace,
InconsistencyError, Ger, ger, ger_destructive)
from unittest import TestCase from unittest import TestCase
from theano.tests import unittest_tools from theano.tests import unittest_tools
from copy import copy, deepcopy from copy import copy, deepcopy
...@@ -29,7 +31,8 @@ import theano.tensor.blas_scipy ...@@ -29,7 +31,8 @@ import theano.tensor.blas_scipy
if config.mode == 'FAST_COMPILE': if config.mode == 'FAST_COMPILE':
mode_not_fast_compile = 'FAST_RUN' mode_not_fast_compile = 'FAST_RUN'
else: mode_not_fast_compile = config.mode else:
mode_not_fast_compile = config.mode
mode_blas_opt = theano.compile.get_default_mode().including('BlasOpt', 'specialize') mode_blas_opt = theano.compile.get_default_mode().including('BlasOpt', 'specialize')
...@@ -675,9 +678,9 @@ def test_inplace1(): ...@@ -675,9 +678,9 @@ def test_inplace1():
def test_dot22(): def test_dot22():
for dtype1 in ['float32', 'float64', 'complex64', 'complex128']: for dtype1 in ['float32', 'float64', 'complex64', 'complex128']:
a=T.matrix(dtype = dtype1) a = T.matrix(dtype=dtype1)
for dtype2 in ['float32', 'float64', 'complex64', 'complex128']: for dtype2 in ['float32', 'float64', 'complex64', 'complex128']:
b=T.matrix(dtype = dtype2) b = T.matrix(dtype=dtype2)
f = theano.function([a,b],T.dot(a,b),mode=mode_blas_opt) f = theano.function([a,b],T.dot(a,b),mode=mode_blas_opt)
topo = f.maker.env.toposort() topo = f.maker.env.toposort()
if dtype1 == dtype2: if dtype1 == dtype2:
...@@ -691,12 +694,12 @@ def test_dot22(): ...@@ -691,12 +694,12 @@ def test_dot22():
bv=rng.uniform(size=b_shp).astype(dtype2) bv=rng.uniform(size=b_shp).astype(dtype2)
f(av,bv) f(av,bv)
cmp((3,4),(4,5)) cmp((3, 4), (4, 5))
cmp((0,4),(4,5)) cmp((0, 4), (4, 5))
cmp((3,0),(0,5)) cmp((3, 0), (0, 5))
cmp((3,4),(4,0)) cmp((3, 4), (4, 0))
cmp((0,4),(4,0)) cmp((0, 4), (4, 0))
cmp((0,0),(0,0)) cmp((0, 0), (0, 0))
def test_dot22scalar(): def test_dot22scalar():
## including does not seem to work for 'local_dot_to_dot22' and ## including does not seem to work for 'local_dot_to_dot22' and
...@@ -706,11 +709,11 @@ def test_dot22scalar(): ...@@ -706,11 +709,11 @@ def test_dot22scalar():
#m = theano.compile.get_default_mode().including('BlasOpt', 'specialize') #m = theano.compile.get_default_mode().including('BlasOpt', 'specialize')
rng = numpy.random.RandomState(unittest_tools.fetch_seed()) rng = numpy.random.RandomState(unittest_tools.fetch_seed())
for dtype1 in ['complex64', 'complex128']: for dtype1 in ['complex64', 'complex128']:
a=T.matrix('a', dtype = dtype1) a = T.matrix('a', dtype=dtype1)
for dtype2 in ['complex64', 'complex128']: for dtype2 in ['complex64', 'complex128']:
b=T.matrix('b', dtype = dtype2) b = T.matrix('b', dtype=dtype2)
for dtype3 in ['complex64', 'complex128']: for dtype3 in ['complex64', 'complex128']:
c=T.matrix('c', dtype = dtype3) c = T.matrix('c', dtype=dtype3)
for dtype4 in ['complex64', 'complex128']: for dtype4 in ['complex64', 'complex128']:
cst = theano.tensor.basic.constant(.2, dtype=dtype4) cst = theano.tensor.basic.constant(.2, dtype=dtype4)
cst2 = theano.tensor.basic.constant(.1, dtype=dtype4) cst2 = theano.tensor.basic.constant(.1, dtype=dtype4)
...@@ -978,11 +981,11 @@ def matrixmultiply(a, b): ...@@ -978,11 +981,11 @@ def matrixmultiply(a, b):
class BaseGemv(object): class BaseGemv(object):
def get_data(self,x_stride=1,y_stride=1): def get_data(self,x_stride=1,y_stride=1):
rng = numpy.random.RandomState(unittest_tools.fetch_seed()) rng = numpy.random.RandomState(unittest_tools.fetch_seed())
mult = array(1, dtype = self.dtype) mult = array(1, dtype=self.dtype)
if self.dtype in [complex64,complex128]: if self.dtype in [complex64,complex128]:
mult = array(1+1j, dtype = self.dtype) mult = array(1 + 1j, dtype=self.dtype)
alpha = array(1., dtype = self.dtype) * mult alpha = array(1., dtype=self.dtype) * mult
beta = array(1., dtype = self.dtype) * mult beta = array(1., dtype=self.dtype) * mult
a = rng.randn(3,3).astype(self.dtype) * mult a = rng.randn(3,3).astype(self.dtype) * mult
x = arange(shape(a)[0]*x_stride,dtype=self.dtype) * mult x = arange(shape(a)[0]*x_stride,dtype=self.dtype) * mult
y = arange(shape(a)[1]*y_stride,dtype=self.dtype) * mult y = arange(shape(a)[1]*y_stride,dtype=self.dtype) * mult
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论