提交 bb5a1127 authored 作者: Frederic's avatar Frederic

Make a register for SpecifyShape.c_code

上级 c3d3ea87
......@@ -4,7 +4,7 @@ from theano.compile.ops import (
Shape_i, register_shape_i_c_code,
ViewOp, view_op, register_view_op_c_code, FromFunctionOp,
as_op, Rebroadcast, register_rebroadcast_c_code,
SpecifyShape, specify_shape)
SpecifyShape, specify_shape, register_specify_shape_c_code)
from theano.compile.function_module import *
......
......@@ -613,6 +613,20 @@ class Rebroadcast(gof.Op):
return tuple(version)
def register_specify_shape_c_code(typ, code, version=()):
""" Tell SpecifyShape how to generate C code for a Theano Type
:param typ: A Theano type. It must be the Theano class itself and not an
instance of the class.
:param code: C code that deep copies the Theano type 'typ'.
Use %(iname)s and %(oname)s for the input and output C
variable names respectively.
%(axis)s for the axis that we need to check.
:param version: A number indicating the version of the code, for cache.
"""
SpecifyShape.c_code_and_version[typ] = (code, version)
class SpecifyShape(gof.Op):
"""
L{Op} that puts into the graph the user-provided shape.
......@@ -629,6 +643,10 @@ class SpecifyShape(gof.Op):
Do C code for them too.
"""
view_map = {0: [0]}
# Mapping from Type to C code (and version) to use.
# In the C code, the name of the input variable is %(iname)s,
# the output variable is %(oname)s.
c_code_and_version = {}
def __hash__(self):
return hash(type(self))
......@@ -692,45 +710,33 @@ class SpecifyShape(gof.Op):
return [None]
return self.make_node(eval_points[0], *inputs[1:]).outputs
def c_code(self, node, nodename, inp, out, sub):
if not isinstance(node.inputs[0], theano.tensor.TensorVariable):
# The C code below supports only Tensor. super.c_code
# will raise an exception to tell that there is no C code
# for the other cases.
return super(SpecifyShape, self).c_code(node, nodename,
inp, out, sub)
iname, shape = inp
oname, = out
def c_code(self, node, name, inames, onames, sub):
iname, shape = inames
oname, = onames
fail = sub['fail']
return """
if (PyArray_NDIM(%(iname)s) != PyArray_DIMS(%(shape)s)[0]) {
PyErr_Format(PyExc_AssertionError,
"SpecifyShape: vector of shape has %%d elements,"
" but the input has %%d dimensions.",
PyArray_NDIM(%(iname)s),
PyArray_DIMS(%(shape)s)[0]);
%(fail)s;
}
for(int i = 0; i < PyArray_NDIM(%(iname)s); i++){
dtype_%(shape)s shp = ((dtype_%(shape)s*)PyArray_GETPTR1(%(shape)s,
i))[0];
if (PyArray_DIMS(%(iname)s)[i] != shp) {
PyErr_Format(PyExc_AssertionError,
"SpecifyShape: dim %%d of input has shape %%d,"
" expected %%d.",
i, PyArray_DIMS(%(iname)s)[i],
shp);
%(fail)s;
}
}
Py_XDECREF(%(oname)s);
%(oname)s = %(iname)s;
Py_XINCREF(%(oname)s);
""" % locals()
itype = node.inputs[0].type.__class__
if itype in self.c_code_and_version:
code, version = self.c_code_and_version[itype]
return code % locals()
return super(SpecifyShape, self).c_code(node, node, inames, onames, sub)
def c_code_cache_version(self):
return (1,)
version = []
# If any of the c code is unversionned, we have to return ()
# Else, we will return a list of (type name, version) pairs.
for t, (c, v) in sorted(self.c_code_and_version.items(),
key=lambda pair: str(pair[0])):
if not v:
warnings.warn("Type %s has C code for SpecifyShape, but it has "
"no version. You should add a 'version' keyword arg "
"when calling register_specify_shape_c_code." % t,
stacklevel=2)
return ()
version.append((str(t), v))
return tuple(version)
specify_shape = SpecifyShape()
......@@ -673,7 +673,6 @@ theano.compile.register_deep_copy_op_c_code(
version=2)
# Register TensorType C code for ViewOp.
theano.compile.register_rebroadcast_c_code(
TensorType,
"""
......@@ -686,3 +685,33 @@ theano.compile.register_rebroadcast_c_code(
}
""",
version=1)
theano.compile.register_specify_shape_c_code(
TensorType,
"""
if (PyArray_NDIM(%(iname)s) != PyArray_DIMS(%(shape)s)[0]) {
PyErr_Format(PyExc_AssertionError,
"SpecifyShape: vector of shape has %%d elements,"
" but the input has %%d dimensions.",
PyArray_NDIM(%(iname)s),
PyArray_DIMS(%(shape)s)[0]);
%(fail)s;
}
for(int i = 0; i < PyArray_NDIM(%(iname)s); i++){
dtype_%(shape)s shp = ((dtype_%(shape)s*)PyArray_GETPTR1(%(shape)s,
i))[0];
if (PyArray_DIMS(%(iname)s)[i] != shp) {
PyErr_Format(PyExc_AssertionError,
"SpecifyShape: dim %%d of input has shape %%d,"
" expected %%d.",
i, PyArray_DIMS(%(iname)s)[i],
shp);
%(fail)s;
}
}
Py_XDECREF(%(oname)s);
%(oname)s = %(iname)s;
Py_XINCREF(%(oname)s);
""",
version=1)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论