提交 53e523b0 authored 作者: james@mackie's avatar james@mackie

added fail cases for dot

上级 db400460
...@@ -953,13 +953,8 @@ inv_elemwise_inplace = inv_elemwise.inplace_version() ...@@ -953,13 +953,8 @@ inv_elemwise_inplace = inv_elemwise.inplace_version()
## Dot product ## ## Dot product ##
class dot(omega_op): class dot(omega_op):
@staticmethod
impl = numpy.dot def _output_shape(xshape, yshape):
def grad(x, y, gz):
return dot(gz, transpose(y)), dot(transpose(x), gz)
def specs(x, y):
xshape = x[2]
yshape = y[2]
if len(xshape) == 0: # x is a scalar if len(xshape) == 0: # x is a scalar
shape = yshape shape = yshape
else: else:
...@@ -971,6 +966,13 @@ class dot(omega_op): ...@@ -971,6 +966,13 @@ class dot(omega_op):
shape = tuple(xshape[:-1]) shape = tuple(xshape[:-1])
else: #y is a scalar else: #y is a scalar
shape = xshape shape = xshape
return shape
impl = numpy.dot
def grad(x, y, gz):
return dot(gz, transpose(y)), dot(transpose(x), gz)
def specs(x, y):
shape = dot._output_shape(x[2], y[2])
return (numpy.ndarray, upcast(x[1], y[1]), shape) return (numpy.ndarray, upcast(x[1], y[1]), shape)
def c_support_code(self): def c_support_code(self):
return blas.cblas_header_text() return blas.cblas_header_text()
...@@ -979,56 +981,7 @@ class dot(omega_op): ...@@ -979,56 +981,7 @@ class dot(omega_op):
def c_impl((_x, _y), (_z, )): def c_impl((_x, _y), (_z, )):
return blas.gemm_code('', '1.0', '0.0') return blas.gemm_code('', '1.0', '0.0')
class gemm(omega_op, inplace): class _testCase_dot(unittest.TestCase):
def impl(z, a, x, y, b):
if b == 0.0:
if a == 1.0:
z[:] = numpy.dot(x,y)
elif a == -1.0:
z[:] = -numpy.dot(x,y)
else:
z[:] = a * numpy.dot(x,y)
elif b == 1.0:
if a == 1.0:
z += numpy.dot(x,y)
elif a == -1.0:
z -= numpy.dot(x,y)
else:
z += a * numpy.dot(x,y)
else:
z *= b
z += a * numpy.dot(x,y)
return z[:]
def grad(z, a, x, y, b, gz):
raise NotImplemented
def specs(z, a, x, y, b):
return z
def alloc(self, except_list):
self.outputs[0].data = self.inputs[0].data
def c_support_code(self):
return blas.cblas_header_text()
def c_libs(self):
return blas.ldflags()
def c_impl((_zin, _a, _x, _y, _b), (_z,)):
check_ab = """
{
if ((_a->descr->type_num != PyArray_DOUBLE)
&& (_a->descr->type_num != PyArray_FLOAT))
goto _dot_execute_fallback;
if ((_b->descr->type_num != PyArray_DOUBLE)
&& (_b->descr->type_num != PyArray_FLOAT))
goto _dot_execute_fallback;
}
"""
return blas.gemm_code( check_ab,
'(_a->descr->type_num == PyArray_FLOAT) ? (REAL)(((float*)_a->data)[0]) : (REAL)(((double*)_a->data)[0])',
'(_b->descr->type_num == PyArray_FLOAT) ? (REAL)(((float*)_b->data)[0]) : (REAL)(((double*)_b->data)[0])')
class _testCase_dotgemm(unittest.TestCase):
def setUp(self): def setUp(self):
build_eval_mode() build_eval_mode()
numpy.random.seed(44) numpy.random.seed(44)
...@@ -1097,6 +1050,140 @@ class _testCase_dotgemm(unittest.TestCase): ...@@ -1097,6 +1050,140 @@ class _testCase_dotgemm(unittest.TestCase):
def test_dot_3d_2d_(self): self.cmp_dot_comp(self.rand(4,5,6), self.rand(6,7)) def test_dot_3d_2d_(self): self.cmp_dot_comp(self.rand(4,5,6), self.rand(6,7))
def test_dot_3d_3d_(self): self.cmp_dot_comp(self.rand(4,5,6), self.rand(8,6,7)) def test_dot_3d_3d_(self): self.cmp_dot_comp(self.rand(4,5,6), self.rand(8,6,7))
def test_dot_fail_1_1(self):
x = numpy.random.rand(5)
y = numpy.random.rand(6)
try:
z = dot(x,y)
except ValueError, e:
self.failUnless(str(e) == 'matrices are not aligned')
return
self.fail()
def test_dot_fail_1_2(self):
x = numpy.random.rand(5)
y = numpy.random.rand(6,4)
try:
z = dot(x,y)
except ValueError, e:
self.failUnless(str(e) == 'matrices are not aligned')
return
self.fail()
def test_dot_fail_1_3(self):
x = numpy.random.rand(5)
y = numpy.random.rand(6,4,7)
try:
z = dot(x,y)
except ValueError, e:
self.failUnless(str(e) == 'objects are not aligned')
return
self.fail()
def test_dot_fail_2_1(self):
x = numpy.random.rand(5,4)
y = numpy.random.rand(6)
try:
z = dot(x,y)
except ValueError, e:
self.failUnless(str(e) == 'matrices are not aligned')
return
self.fail()
def test_dot_fail_2_2(self):
x = numpy.random.rand(5,4)
y = numpy.random.rand(6,7)
try:
z = dot(x,y)
except ValueError, e:
self.failUnless(str(e) == 'matrices are not aligned')
return
self.fail()
def test_dot_fail_2_3(self):
x = numpy.random.rand(5,4)
y = numpy.random.rand(6,7,8)
try:
z = dot(x,y)
except ValueError, e:
self.failUnless(str(e) == 'objects are not aligned')
return
self.fail()
def test_dot_fail_3_1(self):
x = numpy.random.rand(5,4,3)
y = numpy.random.rand(6)
try:
z = dot(x,y)
except ValueError, e:
self.failUnless(str(e) == 'objects are not aligned')
return
self.fail()
def test_dot_fail_3_2(self):
x = numpy.random.rand(5,4,3)
y = numpy.random.rand(6,7)
try:
z = dot(x,y)
except ValueError, e:
self.failUnless(str(e) == 'objects are not aligned')
return
self.fail()
def test_dot_fail_3_3(self):
x = numpy.random.rand(5,4,3)
y = numpy.random.rand(6,7,8)
try:
z = dot(x,y)
except ValueError, e:
self.failUnless(str(e) == 'objects are not aligned')
return
self.fail()
class gemm(omega_op, inplace):
def impl(z, a, x, y, b):
if b == 0.0:
if a == 1.0:
z[:] = numpy.dot(x,y)
elif a == -1.0:
z[:] = -numpy.dot(x,y)
else:
z[:] = a * numpy.dot(x,y)
elif b == 1.0:
if a == 1.0:
z += numpy.dot(x,y)
elif a == -1.0:
z -= numpy.dot(x,y)
else:
z += a * numpy.dot(x,y)
else:
z *= b
z += a * numpy.dot(x,y)
return z[:]
def grad(z, a, x, y, b, gz):
raise NotImplemented
def specs(z, a, x, y, b):
assert z[2] == dot._output_shape(x[2], y[2])
return z
def alloc(self, except_list):
self.outputs[0].data = self.inputs[0].data
def c_support_code(self):
return blas.cblas_header_text()
def c_libs(self):
return blas.ldflags()
def c_impl((_zin, _a, _x, _y, _b), (_z,)):
check_ab = """
{
if ((_a->descr->type_num != PyArray_DOUBLE)
&& (_a->descr->type_num != PyArray_FLOAT))
goto _dot_execute_fallback;
if ((_b->descr->type_num != PyArray_DOUBLE)
&& (_b->descr->type_num != PyArray_FLOAT))
goto _dot_execute_fallback;
}
"""
return blas.gemm_code( check_ab,
'(_a->descr->type_num == PyArray_FLOAT) ? (REAL)(((float*)_a->data)[0]) : (REAL)(((double*)_a->data)[0])',
'(_b->descr->type_num == PyArray_FLOAT) ? (REAL)(((float*)_b->data)[0]) : (REAL)(((double*)_b->data)[0])')
## Transposition ## ## Transposition ##
class transpose(omega_op, view): class transpose(omega_op, view):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论