added gemm, default_input_scalar_dtype, changed NumpyR.set_value

上级 b84f26c0
...@@ -36,11 +36,20 @@ def print_graph(*rs): ...@@ -36,11 +36,20 @@ def print_graph(*rs):
literals_db = {} literals_db = {}
literals_id_db = weakref.WeakValueDictionary() literals_id_db = weakref.WeakValueDictionary()
#input floating point scalars will be cast to arrays of this type
default_input_scalar_dtype = 'float64'
def input(x): def input(x):
#NB:
# - automatically casting int to float seems wrong.
# - we want to be able to write y = x + 1 and maybe have the 1 casted to 1.0
# at some point to maximize speed right?
# - But more important is the ability to store index values without them
# being cast to floating-point (can that cause incorrectness?)
if isinstance(x, numpy.ndarray): if isinstance(x, numpy.ndarray):
return NumpyR(x) return NumpyR(x)
elif isinstance(x, (int, float)): elif isinstance(x, (int, float)):
z = numpy.zeros((), dtype = 'float32') z = numpy.zeros((), dtype = default_input_scalar_dtype)
z += x z += x
return NumpyR(z) return NumpyR(z)
elif isinstance(x, gof.Result): elif isinstance(x, gof.Result):
...@@ -593,17 +602,31 @@ def scalar_switch(normal_f, scalar_f, scalar_f_reverse = None): ...@@ -593,17 +602,31 @@ def scalar_switch(normal_f, scalar_f, scalar_f_reverse = None):
class NumpyR(gof.PythonR): class NumpyR(gof.PythonR):
"""The class for storing ndarray return values from omega ops.
The class provides additional functionality compared to the normal PythonR:
- operator overloads that correspond to omega ops such as add() and scale()
- special attributes that make it behave like an ndarray when passed to
numpy functions.
Attributes:
__array__ - alias of self.data.__array_struct__
__array_struct__ - alias of self.data.__array_struct__
Methods:
set_value() -
"""
# The following attributes make NumpyR instances look like normal ndarray
# instances to many numpy functions, such as argmax(), dot(), svd(), sum(),
# etc. These are documented in the numpy book.
__array__ = property(lambda self: self.data.__array__ )
__array_struct__ = property(lambda self: self.data.__array_struct__ )
def set_value(self, value): def set_value(self, value):
assert value is not None if value is UNCOMPUTED:
if value is None or value is UNCOMPUTED:
self.data = UNCOMPUTED self.data = UNCOMPUTED
elif isinstance(value, numpy.ndarray):
self.data = value
elif isinstance(value, PythonR):
self.set_value(value.data)
else: else:
self.data = numpy.array(value) self.data = numpy.asarray(value)
self.refresh() self.refresh()
self.up_to_date = True self.up_to_date = True
...@@ -886,7 +909,6 @@ class dot(omega_op): ...@@ -886,7 +909,6 @@ class dot(omega_op):
if ((Nx[0] != Nz[0]) || (Nx[1] != Ny[0]) || (Ny[1] != Nz[1])) if ((Nx[0] != Nz[0]) || (Nx[1] != Ny[0]) || (Ny[1] != Nz[1]))
{ {
PyErr_SetString(PyExc_ValueError, "mat_gemm input array size mismatch");
fprintf(stderr, "Should be calling mat_gemm_general, but quitting instead\\n"); fprintf(stderr, "Should be calling mat_gemm_general, but quitting instead\\n");
return 1; return 1;
} }
...@@ -937,6 +959,115 @@ class dot(omega_op): ...@@ -937,6 +959,115 @@ class dot(omega_op):
/* v 1 */ /* v 1 */
""" % dict(dtype = '_x_dtype', gemm = gemm) """ % dict(dtype = '_x_dtype', gemm = gemm)
class gemm(omega_op, inplace):
def impl(z, a, x, y, b):
if b == 0.0:
if a == 1.0:
z[:] = numpy.dot(x,y)
elif a == -1.0:
z[:] = -numpy.dot(x,y)
else:
z[:] = a * numpy.dot(x,y)
elif b == 1.0:
if a == 1.0:
z += numpy.dot(x,y)
elif a == -1.0:
z -= numpy.dot(x,y)
else:
z += a * numpy.dot(x,y)
else:
z *= b
z += a * numpy.dot(x,y)
return z[:]
def grad(z, a, x, y, b, gz):
raise NotImplemented
def specs(z, a, x, y, b):
return z
def alloc(self, except_list):
self.outputs[0].data = self.inputs[0].data
def c_headers(self):
return ["<gsl/gsl_cblas.h>"]
def c_libs(self):
return ["cblas", "atlas", "g2c"]
def c_impl((_z, _a, _x, _y, _b), (_zout,)):
dtype = _x.spec[1]
if dtype.char == 'f':
cblas_gemm = 'cblas_sgemm'
elif dtype.char == 'd':
cblas_gemm = 'cblas_dgemm'
else:
raise NotImplementedError
return """
%(dtype)s a = ((%(dtype)s*)PyArray_DATA(_a))[0];
%(dtype)s b = ((%(dtype)s*)PyArray_DATA(_b))[0];
%(dtype)s* x = (%(dtype)s*)PyArray_DATA(_x);
%(dtype)s* y = (%(dtype)s*)PyArray_DATA(_y);
%(dtype)s* z = (%(dtype)s*)PyArray_DATA(_z);
npy_intp* Nx = _x->dimensions;
npy_intp* Ny = _y->dimensions;
npy_intp* Nz = _z->dimensions;
npy_intp* Sx = _x->strides;
npy_intp* Sy = _y->strides;
npy_intp* Sz = _z->strides;
if ((Nx[0] != Nz[0]) || (Nx[1] != Ny[0]) || (Ny[1] != Nz[1]))
{
fprintf(stderr, "Should be calling mat_gemm_general, but quitting instead\\n");
return 1;
}
if ((Sx[0] < 1) || (Sx[1] < 1)
|| (Sy[0] < 1) || (Sy[1] < 1)
|| (Sz[0] < 1) || (Sz[1] < 1))
{
fprintf(stderr, "Should be calling mat_gemm_general, but quitting instead\\n");
return 1;
//return mat_gemm_general(a, A, B, b, C);
}
//TODO: OPTIMIZE for many special cases:
//- gemv
//- ger
//- ddot
//- others?
int unit = 0;
unit |= ((Sx[1] == sizeof(%(dtype)s)) ? 0x0 : (Sx[0] == sizeof(%(dtype)s)) ? 0x1 : 0x2) << 0;
unit |= ((Sy[1] == sizeof(%(dtype)s)) ? 0x0 : (Sy[0] == sizeof(%(dtype)s)) ? 0x1 : 0x2) << 4;
unit |= ((Sz[1] == sizeof(%(dtype)s)) ? 0x0 : (Sz[0] == sizeof(%(dtype)s)) ? 0x1 : 0x2) << 8;
/* create appropriate strides for malformed matrices that are row or column
* vectors
*/
size_t sx_0 = (Nx[0] > 1) ? Sx[0]/sizeof(%(dtype)s) : Nx[1];
size_t sx_1 = (Nx[1] > 1) ? Sx[1]/sizeof(%(dtype)s) : Nx[0];
size_t sy_0 = (Ny[0] > 1) ? Sy[0]/sizeof(%(dtype)s) : Ny[1];
size_t sy_1 = (Ny[1] > 1) ? Sy[1]/sizeof(%(dtype)s) : Ny[0];
size_t sz_0 = (Nz[0] > 1) ? Sz[0]/sizeof(%(dtype)s) : Nz[1];
size_t sz_1 = (Nz[1] > 1) ? Sz[1]/sizeof(%(dtype)s) : Nz[0];
switch(unit)
{
case 0x000: %(cblas_gemm)s(CblasRowMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_0); break;
case 0x001: %(cblas_gemm)s(CblasRowMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_0); break;
case 0x010: %(cblas_gemm)s(CblasRowMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_0); break;
case 0x011: %(cblas_gemm)s(CblasRowMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_0); break;
case 0x100: %(cblas_gemm)s(CblasColMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_1); break;
case 0x101: %(cblas_gemm)s(CblasColMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_1); break;
case 0x110: %(cblas_gemm)s(CblasColMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_1); break;
case 0x111: %(cblas_gemm)s(CblasColMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_1); break;
default:
fprintf(stderr, "Should be calling mat_gemm_general, but quitting instead\\n");
return 1;
};
/* v 1 */
""" % dict(dtype = '_x_dtype', cblas_gemm = cblas_gemm)
## Transposition ## ## Transposition ##
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论