提交 304c9e03 authored 作者: Frederic's avatar Frederic

Move Rebroadcast() to compile/ops.py and make a register system for the c code.

上级 26d91309
......@@ -3,7 +3,7 @@ from theano.compile.ops import (
Shape, shape, register_shape_c_code,
Shape_i, register_shape_i_c_code,
ViewOp, view_op, register_view_op_c_code, FromFunctionOp,
as_op)
as_op, Rebroadcast, register_rebroadcast_c_code)
from theano.compile.function_module import *
......
......@@ -353,7 +353,7 @@ class Shape_i(gof.Op):
def register_shape_i_c_code(typ, code, version=()):
""" Tell DeepCopyOp how to generate C code for a Theano Type
""" Tell Shape_i how to generate C code for a Theano Type
:param typ: A Theano type. It must be the Theano class itself and not an
instance of the class.
......@@ -461,3 +461,138 @@ def as_op(itypes, otypes, infer_shape=None):
def make_op(fn):
return FromFunctionOp(fn, itypes, otypes, infer_shape)
return make_op
def register_rebroadcast_c_code(typ, code, version=()):
""" Tell Rebroadcast how to generate C code for a Theano Type
:param typ: A Theano type. It must be the Theano class itself and not an
instance of the class.
:param code: C code that deep copies the Theano type 'typ'.
Use %(iname)s and %(oname)s for the input and output C
variable names respectively.
%(axis)s for the axis that we need to check.
:param version: A number indicating the version of the code, for cache.
"""
Rebroadcast.c_code_and_version[typ] = (code, version)
class Rebroadcast(gof.Op):
"""
Change the input's broadcastable fields in
some predetermined way.
e.g.: Rebroadcast((0, True), (1, False))(x)
would make x broadcastable in axis 0
and not broadcastable in axis 1
See also the unbroadcast, addbroadcast and patternbroadcast functions.
..note: work inplace and work for CudaNdarrayType
"""
view_map = {0: [0]}
# Mapping from Type to C code (and version) to use.
# In the C code, the name of the input variable is %(iname)s,
# the output variable is %(oname)s.
c_code_and_version = {}
def __init__(self, *axis):
self.axis = dict(axis)
for axis, broad in self.axis.iteritems():
assert isinstance(axis, (numpy.integer, int)), (
"Rebroadcast need integers axis. Got ", axis)
def __eq__(self, other):
return type(self) == type(other) and self.axis == other.axis
def __hash__(self):
items = self.axis.items()
items.sort() # no ambiguity because each item key is unique
return hash(type(self)) ^ hash(tuple(items))
def __str__(self):
if len(self.axis) == 0:
broadcast_pattern = []
else:
broadcast_pattern = ['?' for i
in xrange(1 + numpy.max(self.axis.keys()))]
for k, v in self.axis.iteritems():
broadcast_pattern[k] = str(int(v))
return '%s{%s}' % (self.__class__.__name__,
','.join(broadcast_pattern))
def make_node(self, x):
if self.axis.keys() and (x.ndim <= numpy.max(self.axis.keys())):
raise ValueError('Trying to rebroadcast nonexistant dimension')
t = x.type.__class__(dtype=x.type.dtype,
broadcastable=[self.axis.get(i, b)
for i, b in enumerate(
x.type.broadcastable)])
return gof.Apply(self, [x], [t()])
def perform(self, node, inp, out_):
x, = inp
out, = out_
for axis, value in self.axis.iteritems():
if value and x.shape[axis] != 1:
raise ValueError('Dimension %s in Rebroadcast\'s input was'
' supposed to be 1 (got %s instead)' %
(axis, x.shape[axis]))
out[0] = x
def grad(self, inp, grads):
x, = inp
gz, = grads
# restore the broadcasting pattern of the input
return Rebroadcast(*[(axis, x.type.broadcastable[axis])
for axis, value in self.axis.iteritems()])(gz),
def infer_shape(self, node, ishapes):
assert len(ishapes) == 1
l = []
one = constant(1)
for ax in xrange(len(ishapes[0])):
if self.axis.get(ax, False):
l.append(one)
else:
l.append(ishapes[0][ax])
return [tuple(l)]
def R_op(self, inputs, eval_points):
if eval_points[0] is None:
return [None]
return self(*eval_points, **dict(return_list=True))
def c_code(self, node, nodename, inp, out, sub):
iname, = inp
oname, = out
fail = sub['fail']
itype = node.inputs[0].type.__class__
if itype in self.c_code_and_version:
code, version = self.c_code_and_version[itype]
final_code = ""
for axis, value in self.axis.iteritems():
if value:
final_code += code % locals()
return final_code + """
Py_XDECREF(%(oname)s);
%(oname)s = %(iname)s;
Py_XINCREF(%(oname)s);
""" % locals()
return super(Rebroadcast, self).c_code(node, nodename, inp, out, sub)
def c_code_cache_version(self):
version = []
# If any of the c code is unversionned, we have to return ()
# Else, we will return a list of (type name, version) pairs.
for t, (c, v) in sorted(self.c_code_and_version.items(),
key=lambda pair: str(pair[0])):
if not v:
warnings.warn("Type %s has C code for Rebroadcast, but it has "
"no version. You should add a 'version' keyword arg "
"when calling register_rebroadcast_c_code." % t,
stacklevel=2)
return ()
version.append((str(t), v))
return tuple(version)
......@@ -25,7 +25,7 @@ from theano.gof.python25 import partial, any, all
from theano.gof.utils import hashtype
from theano import compile, printing
from theano.printing import pprint, min_informative_str
from theano.compile import Shape, shape #For history
from theano.compile import Rebroadcast, Shape, shape #For history
# We use these exceptions as well.
......@@ -3326,119 +3326,6 @@ class Split(Op):
return self.make_node(eval_points[0], *inputs[1:]).outputs
class Rebroadcast(Op):
"""
Change the input's broadcastable fields in
some predetermined way.
e.g.: Rebroadcast((0, True), (1, False))(x)
would make x broadcastable in axis 0
and not broadcastable in axis 1
See also the unbroadcast, addbroadcast and patternbroadcast functions.
..note: work inplace and work for CudaNdarrayType
"""
view_map = {0: [0]}
def __init__(self, *axis):
self.axis = dict(axis)
for axis, broad in self.axis.iteritems():
assert isinstance(axis, (numpy.integer, int)), (
"Rebroadcast need integers axis. Got ", axis)
def __eq__(self, other):
return type(self) == type(other) and self.axis == other.axis
def __hash__(self):
items = self.axis.items()
items.sort() # no ambiguity because each item key is unique
return hash(type(self)) ^ hash(tuple(items))
def __str__(self):
if len(self.axis) == 0:
broadcast_pattern = []
else:
broadcast_pattern = ['?' for i
in xrange(1 + numpy.max(self.axis.keys()))]
for k, v in self.axis.iteritems():
broadcast_pattern[k] = str(int(v))
return '%s{%s}' % (self.__class__.__name__,
','.join(broadcast_pattern))
def make_node(self, x):
if self.axis.keys() and (x.ndim <= numpy.max(self.axis.keys())):
raise ValueError('Trying to rebroadcast nonexistant dimension')
t = x.type.__class__(dtype=x.type.dtype,
broadcastable=[self.axis.get(i, b)
for i, b in enumerate(
x.type.broadcastable)])
return Apply(self, [x], [t()])
def perform(self, node, inp, out_):
x, = inp
out, = out_
for axis, value in self.axis.iteritems():
if value and x.shape[axis] != 1:
raise ValueError('Dimension %s in Rebroadcast\'s input was'
' supposed to be 1 (got %s instead)' %
(axis, x.shape[axis]))
out[0] = x
def grad(self, inp, grads):
x, = inp
gz, = grads
# restore the broadcasting pattern of the input
return Rebroadcast(*[(axis, x.type.broadcastable[axis])
for axis, value in self.axis.iteritems()])(gz),
def infer_shape(self, node, ishapes):
assert len(ishapes) == 1
l = []
one = constant(1)
for ax in xrange(len(ishapes[0])):
if self.axis.get(ax, False):
l.append(one)
else:
l.append(ishapes[0][ax])
return [tuple(l)]
def R_op(self, inputs, eval_points):
if eval_points[0] is None:
return [None]
return self(*eval_points, **dict(return_list=True))
def c_code(self, node, nodename, inp, out, sub):
iname, = inp
oname, = out
fail = sub['fail']
if isinstance(node.inputs[0].type, TensorType):
code = ""
for axis, value in self.axis.iteritems():
if value:
code += """
if(PyArray_DIMS(%(iname)s)[%(axis)s] != 1){
PyErr_Format(PyExc_ValueError,
"Dimension %(axis)s in Rebroadcast's input was"
" supposed to be 1 (got %%d instead)",
PyArray_DIMS(%(iname)s)[%(axis)s]);
%(fail)s
}
""" % locals()
return code + """
Py_XDECREF(%(oname)s);
%(oname)s = %(iname)s;
Py_XINCREF(%(oname)s);
""" % locals()
else:
#TODO: if your type is not listed here, make a damn registry of
# shape_i ops for various types of variables.
# Do not continue this madness.
return super(Rebroadcast, self).c_code(node, nodename, inp, out, sub)
def c_code_cache_version(self):
return (1,)
def addbroadcast(x, *axes):
"""
Make the input broadcastable in the specified axes.
......
......@@ -671,3 +671,18 @@ theano.compile.register_deep_copy_op_c_code(
}
""",
version=2)
# Register TensorType C code for ViewOp.
theano.compile.register_rebroadcast_c_code(
TensorType,
"""
if(PyArray_DIMS(%(iname)s)[%(axis)s] != 1){
PyErr_Format(PyExc_ValueError,
"Dimension %(axis)s in Rebroadcast's input was"
" supposed to be 1 (got %%d instead)",
PyArray_DIMS(%(iname)s)[%(axis)s]);
%(fail)s
}
""",
version=1)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论