提交 4e36f43f 作者: David Warde-Farley

STY: various blas/test_blas fixes

上级 a8646fdc
......@@ -1206,13 +1206,16 @@ class GemmOptimizer(Optimizer):
try:
env.replace_all_validate(
zip(node.outputs, new_outputs),
reason = 'GemmOptimizer')
reason='GemmOptimizer'
)
did_something = True
break
except InconsistencyError, e:
#TODO: retry other applications of gemm (see comment in _gemm_from_node
# TODO: retry other applications of gemm (see comment
# in _gemm_from_node
pass
class Dot22(GemmRelated):
"""Compute a matrix-matrix product.
This is a specialization of the more general Dot()
......@@ -1227,7 +1230,7 @@ class Dot22(GemmRelated):
raise TypeError('dtype mismatch to Dot22')
bz = (x.type.broadcastable[0], y.type.broadcastable[1])
outputs = [T.tensor(x.type.dtype, bz)]
return Apply(self, [x,y], outputs)
return Apply(self, [x, y], outputs)
def perform(self, node, inp, out):
x, y = inp
......@@ -1235,9 +1238,11 @@ class Dot22(GemmRelated):
try:
z[0] = numpy.asarray(numpy.dot(x, y))
except ValueError, e:
# The error raised by numpy has no shape information, we mean to add that
# The error raised by numpy has no shape information, we mean to
# add that
e.args = e.args + (x.shape, y.shape)
raise
def __str__(self):
return "_dot22"
......@@ -1250,10 +1255,12 @@ class Dot22(GemmRelated):
npy_intp dims[2];
dims[0] = %(_x)s->dimensions[0];
dims[1] = %(_y)s->dimensions[1];
%(_zout)s = (PyArrayObject*)PyArray_SimpleNew(2, dims, type_num_%(_x)s);
%(_zout)s = (PyArrayObject*)PyArray_SimpleNew(2, dims,
type_num_%(_x)s);
//fprintf(stderr, "Dot Allocating %%i %%i\\n", dims[0], dims[1]);
if(!%(_zout)s) {
PyErr_SetString(PyExc_MemoryError, "failed to alloc dot22 output");
PyErr_SetString(PyExc_MemoryError,
"failed to alloc dot22 output");
%(fail)s
}
}
......@@ -1270,16 +1277,19 @@ class Dot22(GemmRelated):
double a = 1.0;
double b = 0.0;
"""
def c_code(self, node, name, inp, out, sub): #DEBUG
def c_code(self, node, name, inp, out, sub): # DEBUG
_x, _y = inp
_zout, = out
if node.inputs[0].type.dtype.startswith('complex'):
raise utils.MethodNotDefined('%s.c_code' \
% self.__class__.__name__)
if len(self.c_libraries())<=0:
return super(Dot22, self).c_code(node, name, (_x, _y), (_zout, ), sub)
if len(self.c_libraries()) <= 0:
return super(Dot22, self).c_code(node, name, (_x, _y),
(_zout, ), sub)
full_code = self.build_gemm_call() % dict(locals(), **sub)
return full_code
def c_code_cache_version(self):
gv = self.build_gemm_version()
if gv:
......@@ -1436,6 +1446,7 @@ optdb.register('InplaceBlasOpt',
blas_opt_inplace,
70.0, 'fast_run', 'inplace')
class Dot22Scalar(GemmRelated):
"""Compute a matrix-matrix product.
This is a specialization of the more general Dot()
......@@ -1473,6 +1484,7 @@ class Dot22Scalar(GemmRelated):
# The error raised by numpy has no shape information, we mean to add that
e.args = e.args + (x.shape, y.shape)
raise
def __str__(self):
return "_dot22scalar"
......@@ -1492,6 +1504,7 @@ class Dot22Scalar(GemmRelated):
#undef REAL
float b = 0.0;
"""
case_double_ab_constants = """
#define REAL double
double a = (%(_a)s->descr->type_num == PyArray_FLOAT)
......@@ -1500,16 +1513,18 @@ class Dot22Scalar(GemmRelated):
#undef REAL
double b = 0.0;
"""
def c_code(self, node, name, inp, out, sub): #DEBUG
_x, _y, _a = inp
_zout, = out
if node.inputs[0].type.dtype.startswith('complex'):
raise utils.MethodNotDefined('%s.c_code' \
% self.__class__.__name__)
if len(self.c_libraries())<=0:
if len(self.c_libraries()) <= 0:
return super(Dot22Scalar, self).c_code(node, name, (_x, _y), (_zout, ), sub)
full_code = self.build_gemm_call() % dict(locals(), **sub)
return full_code
def c_code_cache_version(self):
gv = self.build_gemm_version()
if gv:
......@@ -1558,10 +1573,10 @@ def local_dot22_to_dot22scalar(node):
for i,x in enumerate(m.owner.inputs):
if _as_scalar(x) and (theano.scalar.upcast(x.type.dtype,d.type.dtype)
== d.type.dtype):
scalar_idx=i
scalar_idx = i
break
if scalar_idx<0:
if scalar_idx < 0:
_logger.info('Not optimizing dot22 with inputs %s %s, as the type '
'of the scalar cannot be upcasted to the matrix type',
node.inputs, [x.type for x in node.inputs])
......@@ -1593,7 +1608,7 @@ def local_dot22_to_dot22scalar(node):
== d.type.dtype)):
scalar_idx = i
break
if scalar_idx<0:
if scalar_idx < 0:
_logger.info('Not optimizing dot22 with inputs %s %s, as the type '
'of the scalar cannot be upcasted to the matrix type',
node.inputs, [x.type for x in node.inputs])
......
......@@ -5,7 +5,8 @@ import theano.tensor as T
#from theano.gof import Env
from theano.printing import pp
import numpy, theano
import numpy
import theano
from numpy import (arange, array, common_type, complex64, complex128, float32,
float64, newaxis, shape, transpose, zeros)
from numpy.testing import assert_, assert_array_almost_equal
......@@ -13,10 +14,11 @@ from numpy.testing import assert_, assert_array_almost_equal
#from numpy.testing.noseclasses import KnownFailureTest
#from theano.tensor.blas import *
from theano.tensor.blas import (_dot22, _dot22scalar, res_is_a, _as_scalar, _is_real_matrix,
_gemm_canonicalize, _factor_canonicalized, Gemm, Gemv, gemm_inplace, gemm_no_inplace,
InconsistencyError,
Ger, ger, ger_destructive)
from theano.tensor.blas import (_dot22, _dot22scalar, res_is_a, _as_scalar,
_is_real_matrix, _gemm_canonicalize,
_factor_canonicalized, Gemm, Gemv,
gemm_inplace, gemm_no_inplace,
InconsistencyError, Ger, ger, ger_destructive)
from unittest import TestCase
from theano.tests import unittest_tools
from copy import copy, deepcopy
......@@ -29,7 +31,8 @@ import theano.tensor.blas_scipy
if config.mode == 'FAST_COMPILE':
mode_not_fast_compile = 'FAST_RUN'
else: mode_not_fast_compile = config.mode
else:
mode_not_fast_compile = config.mode
mode_blas_opt = theano.compile.get_default_mode().including('BlasOpt', 'specialize')
......@@ -675,9 +678,9 @@ def test_inplace1():
def test_dot22():
for dtype1 in ['float32', 'float64', 'complex64', 'complex128']:
a=T.matrix(dtype = dtype1)
a = T.matrix(dtype=dtype1)
for dtype2 in ['float32', 'float64', 'complex64', 'complex128']:
b=T.matrix(dtype = dtype2)
b = T.matrix(dtype=dtype2)
f = theano.function([a,b],T.dot(a,b),mode=mode_blas_opt)
topo = f.maker.env.toposort()
if dtype1 == dtype2:
......@@ -691,12 +694,12 @@ def test_dot22():
bv=rng.uniform(size=b_shp).astype(dtype2)
f(av,bv)
cmp((3,4),(4,5))
cmp((0,4),(4,5))
cmp((3,0),(0,5))
cmp((3,4),(4,0))
cmp((0,4),(4,0))
cmp((0,0),(0,0))
cmp((3, 4), (4, 5))
cmp((0, 4), (4, 5))
cmp((3, 0), (0, 5))
cmp((3, 4), (4, 0))
cmp((0, 4), (4, 0))
cmp((0, 0), (0, 0))
def test_dot22scalar():
## including does not seem to work for 'local_dot_to_dot22' and
......@@ -706,11 +709,11 @@ def test_dot22scalar():
#m = theano.compile.get_default_mode().including('BlasOpt', 'specialize')
rng = numpy.random.RandomState(unittest_tools.fetch_seed())
for dtype1 in ['complex64', 'complex128']:
a=T.matrix('a', dtype = dtype1)
a = T.matrix('a', dtype=dtype1)
for dtype2 in ['complex64', 'complex128']:
b=T.matrix('b', dtype = dtype2)
b = T.matrix('b', dtype=dtype2)
for dtype3 in ['complex64', 'complex128']:
c=T.matrix('c', dtype = dtype3)
c = T.matrix('c', dtype=dtype3)
for dtype4 in ['complex64', 'complex128']:
cst = theano.tensor.basic.constant(.2, dtype=dtype4)
cst2 = theano.tensor.basic.constant(.1, dtype=dtype4)
......@@ -978,11 +981,11 @@ def matrixmultiply(a, b):
class BaseGemv(object):
def get_data(self,x_stride=1,y_stride=1):
rng = numpy.random.RandomState(unittest_tools.fetch_seed())
mult = array(1, dtype = self.dtype)
mult = array(1, dtype=self.dtype)
if self.dtype in [complex64,complex128]:
mult = array(1+1j, dtype = self.dtype)
alpha = array(1., dtype = self.dtype) * mult
beta = array(1., dtype = self.dtype) * mult
mult = array(1+1j, dtype=self.dtype)
alpha = array(1., dtype=self.dtype) * mult
beta = array(1., dtype=self.dtype) * mult
a = rng.randn(3,3).astype(self.dtype) * mult
x = arange(shape(a)[0]*x_stride,dtype=self.dtype) * mult
y = arange(shape(a)[1]*y_stride,dtype=self.dtype) * mult
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论