提交 d632b00a authored 作者: nouiz's avatar nouiz

Merge pull request #386 from lamblin/fix_blas_tests

Fix blas tests
...@@ -47,9 +47,9 @@ def ger_c_code(A, a, x, y, Z, destructive, fail): ...@@ -47,9 +47,9 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
{ PyErr_SetString(PyExc_TypeError, "A vs. y"); %(fail)s; } { PyErr_SetString(PyExc_TypeError, "A vs. y"); %(fail)s; }
if (%(A)s->dimensions[0] != %(x)s->dimensions[0]) if (%(A)s->dimensions[0] != %(x)s->dimensions[0])
{PyErr_SetString(PyExc_ValueError, "A.shape[0] != x.shape[0]"); %(fail)s;} {PyErr_SetString(PyExc_ValueError, "Shape mismatch: A.shape[0] != x.shape[0]"); %(fail)s;}
if (%(A)s->dimensions[1] != %(y)s->dimensions[0]) if (%(A)s->dimensions[1] != %(y)s->dimensions[0])
{PyErr_SetString(PyExc_ValueError, "A.shape[1] != y.shape[0]"); %(fail)s;} {PyErr_SetString(PyExc_ValueError, "Shape mismatch: A.shape[1] != y.shape[0]"); %(fail)s;}
if (%(A)s->descr->type_num == PyArray_DOUBLE) { elemsize = 8; } if (%(A)s->descr->type_num == PyArray_DOUBLE) { elemsize = 8; }
else if (%(A)s->descr->type_num == PyArray_FLOAT) { elemsize = 4;} else if (%(A)s->descr->type_num == PyArray_FLOAT) { elemsize = 4;}
...@@ -198,14 +198,17 @@ class CGer(BaseBLAS, Ger): ...@@ -198,14 +198,17 @@ class CGer(BaseBLAS, Ger):
return code return code
def c_code_cache_version(self): def c_code_cache_version(self):
return (2,) return (3,)
@local_optimizer([ger, ger_destructive]) @local_optimizer([ger, ger_destructive])
def use_c_ger(node): def use_c_ger(node):
if node.op == ger: # Only float32 and float64 are supported for now.
if (node.op == ger and
node.outputs[0].dtype in ['float32', 'float64']):
return [CGer(False)(*node.inputs)] return [CGer(False)(*node.inputs)]
if node.op == ger_destructive: if (node.op == ger_destructive and
node.outputs[0].dtype in ['float32', 'float64']):
return [CGer(True)(*node.inputs)] return [CGer(True)(*node.inputs)]
@local_optimizer([CGer(False)]) @local_optimizer([CGer(False)])
...@@ -249,9 +252,9 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail): ...@@ -249,9 +252,9 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail):
{ PyErr_SetString(PyExc_TypeError, "Gemv: aa vs. yy"); %(fail)s; } { PyErr_SetString(PyExc_TypeError, "Gemv: aa vs. yy"); %(fail)s; }
if (%(xx)s->dimensions[0] != %(aa)s->dimensions[0]) if (%(xx)s->dimensions[0] != %(aa)s->dimensions[0])
{PyErr_SetString(PyExc_ValueError, "A.shape[0] != x.shape[0]"); %(fail)s;} {PyErr_SetString(PyExc_ValueError, "Shape mismatch: A.shape[0] != x.shape[0]"); %(fail)s;}
if (%(xx)s->dimensions[1] != %(yy)s->dimensions[0]) if (%(xx)s->dimensions[1] != %(yy)s->dimensions[0])
{PyErr_SetString(PyExc_ValueError, "A.shape[1] != y.shape[0]"); %(fail)s;} {PyErr_SetString(PyExc_ValueError, "Shape mismatch: A.shape[1] != y.shape[0]"); %(fail)s;}
if (%(aa)s->descr->type_num == PyArray_DOUBLE) { elemsize = 8; } if (%(aa)s->descr->type_num == PyArray_DOUBLE) { elemsize = 8; }
else if (%(aa)s->descr->type_num == PyArray_FLOAT) { elemsize = 4;} else if (%(aa)s->descr->type_num == PyArray_FLOAT) { elemsize = 4;}
...@@ -420,14 +423,17 @@ class CGemv(BaseBLAS, Gemv): ...@@ -420,14 +423,17 @@ class CGemv(BaseBLAS, Gemv):
return code return code
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return (2,)
@local_optimizer([gemv_inplace, gemv_no_inplace]) @local_optimizer([gemv_inplace, gemv_no_inplace])
def use_c_gemv(node): def use_c_gemv(node):
if node.op == gemv_no_inplace: # Only float32 and float64 are supported for now.
if (node.op == gemv_no_inplace and
node.outputs[0].dtype in ['float32', 'float64']):
return [CGemv(inplace=False)(*node.inputs)] return [CGemv(inplace=False)(*node.inputs)]
if node.op == gemv_inplace: if (node.op == gemv_inplace and
node.outputs[0].dtype in ['float32', 'float64']):
return [CGemv(inplace=True)(*node.inputs)] return [CGemv(inplace=True)(*node.inputs)]
......
...@@ -3531,7 +3531,7 @@ class t_dot(unittest.TestCase): ...@@ -3531,7 +3531,7 @@ class t_dot(unittest.TestCase):
self.assertTrue( self.assertTrue(
# Reported by numpy. # Reported by numpy.
e[0].split()[1:4] == ['are', 'not', 'aligned'] or e[0].split()[1:4] == ['are', 'not', 'aligned'] or
# Reported by blas. # Reported by blas or Theano.
e[0].split()[0:2] == ['Shape', 'mismatch:'] or e[0].split()[0:2] == ['Shape', 'mismatch:'] or
# Reported by Theano when 'exception_verbosity' is set # Reported by Theano when 'exception_verbosity' is set
# to 'high'. # to 'high'.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论