提交 8fdb9e26 authored 作者: Frederic's avatar Frederic

Make SpecifyShape work on the GPU and tests.

上级 bb5a1127
...@@ -613,7 +613,8 @@ class Rebroadcast(gof.Op): ...@@ -613,7 +613,8 @@ class Rebroadcast(gof.Op):
return tuple(version) return tuple(version)
def register_specify_shape_c_code(typ, code, version=()): def register_specify_shape_c_code(typ, code, version=(),
c_support_code_apply=None):
""" Tell SpecifyShape how to generate C code for a Theano Type """ Tell SpecifyShape how to generate C code for a Theano Type
:param typ: A Theano type. It must be the Theano class itself and not an :param typ: A Theano type. It must be the Theano class itself and not an
...@@ -623,8 +624,9 @@ def register_specify_shape_c_code(typ, code, version=()): ...@@ -623,8 +624,9 @@ def register_specify_shape_c_code(typ, code, version=()):
variable names respectively. variable names respectively.
%(axis)s for the axis that we need to check. %(axis)s for the axis that we need to check.
:param version: A number indicating the version of the code, for cache. :param version: A number indicating the version of the code, for cache.
:param c_support_code_apply: extra code.
""" """
SpecifyShape.c_code_and_version[typ] = (code, version) SpecifyShape.c_code_and_version[typ] = (code, version, c_support_code_apply)
class SpecifyShape(gof.Op): class SpecifyShape(gof.Op):
...@@ -710,6 +712,14 @@ class SpecifyShape(gof.Op): ...@@ -710,6 +712,14 @@ class SpecifyShape(gof.Op):
return [None] return [None]
return self.make_node(eval_points[0], *inputs[1:]).outputs return self.make_node(eval_points[0], *inputs[1:]).outputs
def c_support_code_apply(self, node, name):
itype = node.inputs[0].type.__class__
if itype in self.c_code_and_version:
_, _, support_code = self.c_code_and_version[itype]
if support_code:
return support_code
return super(SpecifyShape, self).c_support_code_apply(node, name)
def c_code(self, node, name, inames, onames, sub): def c_code(self, node, name, inames, onames, sub):
iname, shape = inames iname, shape = inames
oname, = onames oname, = onames
...@@ -717,7 +727,7 @@ class SpecifyShape(gof.Op): ...@@ -717,7 +727,7 @@ class SpecifyShape(gof.Op):
itype = node.inputs[0].type.__class__ itype = node.inputs[0].type.__class__
if itype in self.c_code_and_version: if itype in self.c_code_and_version:
code, version = self.c_code_and_version[itype] code, version, _ = self.c_code_and_version[itype]
return code % locals() return code % locals()
return super(SpecifyShape, self).c_code(node, node, inames, onames, sub) return super(SpecifyShape, self).c_code(node, node, inames, onames, sub)
...@@ -726,7 +736,7 @@ class SpecifyShape(gof.Op): ...@@ -726,7 +736,7 @@ class SpecifyShape(gof.Op):
version = [] version = []
# If any of the c code is unversionned, we have to return () # If any of the c code is unversionned, we have to return ()
# Else, we will return a list of (type name, version) pairs. # Else, we will return a list of (type name, version) pairs.
for t, (c, v) in sorted(self.c_code_and_version.items(), for t, (c, v, _) in sorted(self.c_code_and_version.items(),
key=lambda pair: str(pair[0])): key=lambda pair: str(pair[0])):
if not v: if not v:
warnings.warn("Type %s has C code for SpecifyShape, but it has " warnings.warn("Type %s has C code for SpecifyShape, but it has "
......
...@@ -282,7 +282,10 @@ def local_gpua_dimshuffle(node): ...@@ -282,7 +282,10 @@ def local_gpua_dimshuffle(node):
@register_opt() @register_opt()
@op_lifter([tensor.SpecifyShape]) @op_lifter([tensor.SpecifyShape])
def local_gpua_specifyShape(node): def local_gpua_specifyShape(node):
return tensor.specify_shape if isinstance(node.inputs[0].type, GpuArrayType):
return
inp = [gpu_from_host(node.inputs[0])] + node.inputs[1:]
return tensor.specify_shape(*inp)
@register_opt() @register_opt()
......
...@@ -11,6 +11,8 @@ from theano.sandbox.gpuarray.tests.test_basic_ops import ( ...@@ -11,6 +11,8 @@ from theano.sandbox.gpuarray.tests.test_basic_ops import (
rand_gpuarray, mode_with_gpu, mode_without_gpu rand_gpuarray, mode_with_gpu, mode_without_gpu
) )
from theano.tests.unittest_tools import SkipTest from theano.tests.unittest_tools import SkipTest
from theano.tensor.tests.test_basic import TestSpecifyShape
def test_flatten(): def test_flatten():
m = theano.tensor.fmatrix() m = theano.tensor.fmatrix()
...@@ -108,3 +110,9 @@ def test_rebroadcast(): ...@@ -108,3 +110,9 @@ def test_rebroadcast():
assert isinstance(rebr.inputs[0].type, GpuArrayType) assert isinstance(rebr.inputs[0].type, GpuArrayType)
assert isinstance(rebr.outputs[0].type, GpuArrayType) assert isinstance(rebr.outputs[0].type, GpuArrayType)
class TestSpecifyShape(TestSpecifyShape):
mode = mode_with_gpu
input_type = GpuArrayType
pass
...@@ -339,4 +339,34 @@ theano.compile.register_rebroadcast_c_code( ...@@ -339,4 +339,34 @@ theano.compile.register_rebroadcast_c_code(
%(fail)s %(fail)s
} }
""", """,
version=1) version=1)
theano.compile.register_specify_shape_c_code(
GpuArrayType,
"""
if (PyGpuArray_NDIM(%(iname)s) != PyArray_DIMS(%(shape)s)[0]) {
PyErr_Format(PyExc_AssertionError,
"SpecifyShape: vector of shape has %%d elements,"
" but the input has %%d dimensions.",
PyGpuArray_NDIM(%(iname)s),
PyArray_DIMS(%(shape)s)[0]);
%(fail)s;
}
for(int i = 0; i < PyGpuArray_NDIM(%(iname)s); i++){
dtype_%(shape)s shp = ((dtype_%(shape)s*)PyArray_GETPTR1(%(shape)s,
i))[0];
if (PyGpuArray_DIMS(%(iname)s)[i] != shp) {
PyErr_Format(PyExc_AssertionError,
"SpecifyShape: dim %%d of input has shape %%d,"
" expected %%d.",
i, PyGpuArray_DIMS(%(iname)s)[i],
shp);
%(fail)s;
}
}
Py_XDECREF(%(oname)s);
%(oname)s = %(iname)s;
Py_XINCREF(%(oname)s);
""",
version=1,
c_support_code_apply='#include <numpy_compat.h>')
...@@ -6178,7 +6178,11 @@ def test_stacklists(): ...@@ -6178,7 +6178,11 @@ def test_stacklists():
x = numpy.ones((4, 4), 'float32') x = numpy.ones((4, 4), 'float32')
assert f(x,x,x,x).shape == (2, 2, 4, 4) assert f(x,x,x,x).shape == (2, 2, 4, 4)
class TestSpecifyShape(unittest.TestCase): class TestSpecifyShape(unittest.TestCase):
mode = None
input_type = TensorType
def shortDescription(self): def shortDescription(self):
return None return None
...@@ -6189,14 +6193,21 @@ class TestSpecifyShape(unittest.TestCase): ...@@ -6189,14 +6193,21 @@ class TestSpecifyShape(unittest.TestCase):
x = vector() x = vector()
xval = numpy.random.rand(2).astype(floatX) xval = numpy.random.rand(2).astype(floatX)
f = theano.function([x], specify_shape(x, [2])) f = theano.function([x], specify_shape(x, [2]), mode=self.mode)
f(xval) f(xval)
xval = numpy.random.rand(3).astype(floatX) xval = numpy.random.rand(3).astype(floatX)
self.assertRaises(AssertionError, f, xval) self.assertRaises(AssertionError, f, xval)
theano.printing.debugprint(f)
assert isinstance([n for n in f.maker.fgraph.toposort()
if isinstance(n.op, SpecifyShape)][0].inputs[0].type,
self.input_type)
x = matrix() x = matrix()
xval = numpy.random.rand(2, 3).astype(floatX) xval = numpy.random.rand(2, 3).astype(floatX)
f = theano.function([x], specify_shape(x, [2, 3])) f = theano.function([x], specify_shape(x, [2, 3]), mode=self.mode)
assert isinstance([n for n in f.maker.fgraph.toposort()
if isinstance(n.op, SpecifyShape)][0].inputs[0].type,
self.input_type)
f(xval) f(xval)
for shape in [(1, 3), (2, 2), (5, 5)]: for shape in [(1, 3), (2, 2), (5, 5)]:
xval = numpy.random.rand(*shape).astype(floatX) xval = numpy.random.rand(*shape).astype(floatX)
...@@ -6212,7 +6223,11 @@ class TestSpecifyShape(unittest.TestCase): ...@@ -6212,7 +6223,11 @@ class TestSpecifyShape(unittest.TestCase):
self.assertRaises(AssertionError, specify_shape, x, []) self.assertRaises(AssertionError, specify_shape, x, [])
self.assertRaises(AssertionError, specify_shape, x, [2, 2]) self.assertRaises(AssertionError, specify_shape, x, [2, 2])
f = theano.function([x, shape_vec], specify_shape(x, shape_vec)) f = theano.function([x, shape_vec], specify_shape(x, shape_vec),
mode=self.mode)
assert isinstance([n for n in f.maker.fgraph.toposort()
if isinstance(n.op, SpecifyShape)][0].inputs[0].type,
self.input_type)
self.assertRaises(AssertionError, f, xval, []) self.assertRaises(AssertionError, f, xval, [])
self.assertRaises(AssertionError, f, xval, [2, 2]) self.assertRaises(AssertionError, f, xval, [2, 2])
...@@ -6222,7 +6237,11 @@ class TestSpecifyShape(unittest.TestCase): ...@@ -6222,7 +6237,11 @@ class TestSpecifyShape(unittest.TestCase):
(1,), (1,),
(2, 3, 4)]: (2, 3, 4)]:
self.assertRaises(AssertionError, specify_shape, x, shape) self.assertRaises(AssertionError, specify_shape, x, shape)
f = theano.function([x, shape_vec], specify_shape(x, shape_vec)) f = theano.function([x, shape_vec], specify_shape(x, shape_vec),
mode=self.mode)
assert isinstance([n for n in f.maker.fgraph.toposort()
if isinstance(n.op, SpecifyShape)][0].inputs[0].type,
self.input_type)
self.assertRaises(AssertionError, f, xval, shape) self.assertRaises(AssertionError, f, xval, shape)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论