dot test cases

上级 2b6275d0
import unittest, os, sys
for filename in os.listdir('.'):
if filename == __file__: continue
#continue
if filename[-3:] == '.py':
modname = filename[:-3]
suite = unittest.TestLoader().loadTestsFromModule(__import__(modname))
#suite.addTests(unittest.TestLoader().loadTestsFromModule(__import__(modname)))
if suite.countTestCases() > 0:
print >>sys.stderr, 'Testing', modname, '(%s)'% (filename),
unittest.TextTestRunner(verbosity=1).run(suite)
......@@ -31,7 +31,7 @@ def experimental_linker(env, target = None):
except NotImplementedError:
result = op._perform
py_ops.add(op)
thunks.append((result, op._perform_like_c))
thunks.append((result, op._perform_inplace))
def ret():
for thunk, fallback in thunks:
......
......@@ -16,6 +16,7 @@ from gof import current_mode, set_mode, build_mode, eval_mode, build_eval_mode,
import type_spec
import cutils
import blas
import compile
# __all__ = ['set_mode', 'get_mode', 'NumpyR', 'NumpyOp']
......@@ -37,6 +38,14 @@ def as_string(*rs):
def print_graph(*rs):
print as_string(*rs)
def _approx_eq(a,b,eps=1.0e-9):
a = numpy.asarray(a)
b = numpy.asarray(b)
if a.shape != b.shape:
return False
d = abs(a-b)
return numpy.all(d < eps)
literals_db = {}
#literals_id_db = weakref.WeakValueDictionary()
......@@ -680,6 +689,17 @@ class NumpyR(gof.PythonR):
self.refresh()
self.up_to_date = True
def set_value_inplace(self, value):
if value is UNCOMPUTED:
raise ValueError()
else:
if 0 == len(self.data.shape):
self.data.itemset(value)
else:
self.data[:] = value
self.refresh()
self.up_to_date = True
def refresh(self):
if self.data is not UNCOMPUTED:
self.spec = (numpy.ndarray, self.data.dtype, self.data.shape)
......@@ -938,9 +958,19 @@ class dot(omega_op):
def grad(x, y, gz):
return dot(gz, transpose(y)), dot(transpose(x), gz)
def specs(x, y):
# todo: handle all tensors!
assert x[2][1] == y[2][0]
shape = (x[2][0], y[2][1])
xshape = x[2]
yshape = y[2]
if len(xshape) == 0: # x is a scalar
shape = yshape
else:
if len(yshape) >= 2: #y is a matrix or tensor
assert xshape[-1] == yshape[-2]
shape = tuple(xshape[:-1]+ yshape[:-2]+yshape[-1:])
elif len(yshape)==1: #y is vector
assert xshape[-1] == yshape[-1]
shape = tuple(xshape[:-1])
else: #y is a scalar
shape = xshape
return (numpy.ndarray, upcast(x[1], y[1]), shape)
def c_support_code(self):
return blas.cblas_header_text()
......@@ -998,6 +1028,74 @@ class gemm(omega_op, inplace):
'(_a->descr->type_num == PyArray_FLOAT) ? (REAL)(((float*)_a->data)[0]) : (REAL)(((double*)_a->data)[0])',
'(_b->descr->type_num == PyArray_FLOAT) ? (REAL)(((float*)_b->data)[0]) : (REAL)(((double*)_b->data)[0])')
class _testCase_dotgemm(unittest.TestCase):
def setUp(self):
build_eval_mode()
numpy.random.seed(44)
def tearDown(self):
pop_mode()
@staticmethod
def rand(*args):
return numpy.random.rand(*args)
def cmp_dot(self,x,y):
def spec(x):
x = numpy.asarray(x)
return type(x), x.dtype, x.shape
zspec = dot.specs(spec(x), spec(y))
nz = numpy.dot(x,y)
self.failUnless(zspec == spec(nz))
self.failUnless(_approx_eq(dot(x,y), numpy.dot(x,y)))
def cmp_dot_comp(self, x,y):
x = numpy.asarray(x)
y = numpy.asarray(y)
z = dot(x,y)
p = compile.single(z)
if len(x.shape):
x[:] = numpy.random.rand(*x.shape)
else:
x.fill(numpy.random.rand(*x.shape))
if len(y.shape):
y[:] = numpy.random.rand(*y.shape)
else:
y.fill(numpy.random.rand(*y.shape))
p() # recalculate z
self.failUnless(_approx_eq(z, numpy.dot(x,y)))
def test_dot_0d_0d(self): self.cmp_dot(1.1, 2.2)
def test_dot_0d_1d(self): self.cmp_dot(1.1, self.rand(5))
def test_dot_0d_2d(self): self.cmp_dot(3.0, self.rand(6,7))
def test_dot_0d_3d(self): self.cmp_dot(3.0, self.rand(8,6,7))
def test_dot_1d_0d(self): self.cmp_dot(self.rand(5), 1.1 )
def test_dot_1d_1d(self): self.cmp_dot(self.rand(5), self.rand(5))
def test_dot_1d_2d(self): self.cmp_dot(self.rand(6), self.rand(6,7))
def test_dot_1d_3d(self): self.cmp_dot(self.rand(6), self.rand(8,6,7))
def test_dot_2d_0d(self): self.cmp_dot(self.rand(5,6), 1.0)
def test_dot_2d_1d(self): self.cmp_dot(self.rand(5,6), self.rand(6))
def test_dot_2d_2d(self): self.cmp_dot(self.rand(5,6), self.rand(6,7))
def test_dot_2d_3d(self): self.cmp_dot(self.rand(5,6), self.rand(8,6,7))
def test_dot_3d_0d(self): self.cmp_dot(self.rand(4,5,6), 1.0)
def test_dot_3d_1d(self): self.cmp_dot(self.rand(4,5,6), self.rand(6))
def test_dot_3d_2d(self): self.cmp_dot(self.rand(4,5,6), self.rand(6,7))
def test_dot_3d_3d(self): self.cmp_dot(self.rand(4,5,6), self.rand(8,6,7))
def test_dot_0d_0d_(self): self.cmp_dot_comp(1.1, 2.2)
def test_dot_0d_1d_(self): self.cmp_dot_comp(1.1, self.rand(5))
def test_dot_0d_2d_(self): self.cmp_dot_comp(3.0, self.rand(6,7))
def test_dot_0d_3d_(self): self.cmp_dot_comp(3.0, self.rand(8,6,7))
def test_dot_1d_0d_(self): self.cmp_dot_comp(self.rand(5), 1.1 )
def test_dot_1d_1d_(self): self.cmp_dot_comp(self.rand(5), self.rand(5))
def test_dot_1d_2d_(self): self.cmp_dot_comp(self.rand(6), self.rand(6,7))
def test_dot_1d_3d_(self): self.cmp_dot_comp(self.rand(6), self.rand(8,6,7))
def test_dot_2d_0d_(self): self.cmp_dot_comp(self.rand(5,6), 1.0)
def test_dot_2d_1d_(self): self.cmp_dot_comp(self.rand(5,6), self.rand(6))
def test_dot_2d_2d_(self): self.cmp_dot_comp(self.rand(5,6), self.rand(6,7))
def test_dot_2d_3d_(self): self.cmp_dot_comp(self.rand(5,6), self.rand(8,6,7))
def test_dot_3d_0d_(self): self.cmp_dot_comp(self.rand(4,5,6), 1.0)
def test_dot_3d_1d_(self): self.cmp_dot_comp(self.rand(4,5,6), self.rand(6))
def test_dot_3d_2d_(self): self.cmp_dot_comp(self.rand(4,5,6), self.rand(6,7))
def test_dot_3d_3d_(self): self.cmp_dot_comp(self.rand(4,5,6), self.rand(8,6,7))
## Transposition ##
......
......@@ -102,6 +102,9 @@ class PythonR(Result):
self.up_to_date = True
self.refresh()
def set_value_inplace(self, value):
raise NotImplementedError()
def __str__(self):
return str(self.data)
......@@ -231,14 +234,14 @@ class PythonOp(Op):
for result, output in zip(results, self.outputs):
output.set_value(result)
def _perform_like_c(self):
def _perform_inplace(self):
results = self._impl()
if self.nout == 1:
self.outputs[0].data[:] = results
self.out.set_value_inplace(results)
else:
assert self.nout == len(results)
for result, output in zip(results, self.outputs):
output.data[:] = result
output.set_value_inplace(result)
def _impl(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论