提交 85aede29 authored 作者: Frederic's avatar Frederic

copied the old gpueye code and test.

上级 ecbb2692
...@@ -559,3 +559,95 @@ class GpuReshape(HideC, tensor.Reshape): ...@@ -559,3 +559,95 @@ class GpuReshape(HideC, tensor.Reshape):
else: else:
raise ValueError("total size of new array must be unchanged") raise ValueError("total size of new array must be unchanged")
out[0] = x.reshape(tuple(shp)) out[0] = x.reshape(tuple(shp))
class GpuEye(GpuOp):
def __init__(self, dtype=None):
if dtype is None:
dtype = config.floatX
assert dtype == 'float32'
self.dtype = dtype
def make_node(self, n, m, k):
n = tensor.as_tensor_variable(n)
m = tensor.as_tensor_variable(m)
k = tensor.as_tensor_variable(k)
assert n.ndim == 0
assert m.ndim == 0
assert k.ndim == 0
# k != 0 isn't implemented on the GPU yet.
assert tensor.get_scalar_constant_value(k) == 0
return Apply(self, [n, m], [matrix(dtype=self.dtype)])
def infer_shape(self, node, in_shapes):
out_shape = [node.inputs[0], node.inputs[1]]
return [out_shape]
def grad(self, inp, grads):
return [grad_undefined(self, i, inp[i]) for i in xrange(3)]
def __eq__(self, other):
return type(self) == type(other) and self.dtype == other.dtype
def __hash__(self):
return hash(self.dtype) ^ hash(type(self))
def c_support_code(self):
return """
//Only 1 block is used.
__global__ void kEye(float* a, int n, int m) {
int nb_elem = min(n, m);
for (unsigned int i = threadIdx.x; i < nb_elem; i += blockDim.x) {
a[i*m + i] = 1;
}
}"""
def c_code(self, node, name, inp, out, sub):
n, m = inp
z, = out
fail = sub['fail']
s = """
int dims[] = {0, 0};
dims[0] = ((dtype_%(n)s*)PyArray_DATA(%(n)s))[0];
dims[1] = ((dtype_%(m)s*)PyArray_DATA(%(m)s))[0];
int total_size = dims[0] * dims[1] * sizeof(float);
cudaError_t sts;
void * orig_z = %(z)s;
if (CudaNdarray_prep_output(&%(z)s, 2, dims))
{
%(fail)s;
}
sts = cudaMemset(CudaNdarray_DEV_DATA(%(z)s), 0, total_size);
if (cudaSuccess != sts)
{
PyErr_Format(PyExc_MemoryError,
"GpuEye: Error in memset %%d bytes of device memory.",
total_size);
if(orig_z == NULL)
Py_XDECREF(%(z)s);
%(fail)s;
}
kEye<<<1, 256>>>(CudaNdarray_DEV_DATA(%(z)s), dims[0], dims[1]);
CNDA_THREAD_SYNC;
sts = cudaGetLastError();
if (cudaSuccess != sts)
{
PyErr_Format(PyExc_RuntimeError,
"Cuda error: kEye: %%s. n=%%d, m=%%d.",
cudaGetErrorString(sts),
dims[0], dims[1]);
%(fail)s;
}
""" % locals()
return s
def c_code_cache_version(self):
return (3,)
gpu_eye = GpuEye(dtype='float32')
...@@ -306,3 +306,32 @@ class G_reshape(T_reshape): ...@@ -306,3 +306,32 @@ class G_reshape(T_reshape):
theano.tensor.opt.Shape_i, theano.tensor.opt.Shape_i,
theano.tensor.opt.MakeVector)) theano.tensor.opt.MakeVector))
assert self.op == GpuReshape assert self.op == GpuReshape
def test_gpueye():
def check(dtype, N, M_=None):
# Theano does not accept None as a tensor.
# So we must use a real value.
M = M_
# Currently DebugMode does not support None as inputs even if this is
# allowed.
if M is None:
M = N
N_symb = T.iscalar()
M_symb = T.iscalar()
k_symb = numpy.asarray(0)
out = T.eye(N_symb, M_symb, k_symb, dtype=dtype)
f = theano.function([N_symb, M_symb],
B.as_cuda_ndarray_variable(out),
mode=mode_with_gpu)
result = numpy.asarray(f(N, M))
assert numpy.allclose(result, numpy.eye(N, M_, dtype=dtype))
assert result.dtype == numpy.dtype(dtype)
assert any([isinstance(node.op, B.GpuEye)
for node in f.maker.fgraph.toposort()])
for dtype in ['float32']:
yield check, dtype, 3
# M != N, k = 0
yield check, dtype, 3, 5
yield check, dtype, 5, 3
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论