提交 a87cca23 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add support for single-elements itypes and otypes and optional infer_shape callable.

上级 c5b999ed
......@@ -370,17 +370,19 @@ class FromFunctionOp(gof.Op):
Build a basic theano Op around a function.
Since the resulting Op is very basic and is missing most of the
optional functionality, optimizations that rely on shape
information will have trouble with this.
optional functionality, some optimization may not apply. If you
want to help, you can supply an infer_shape function that computes
the shapes of the output given the shapes of the inputs.
Also the gradient is undefined in the resulting op and theano will
raise an error if you attempt to get the gradient of a graph
containing this op.
"""
def __init__(self, fn, itypes, otypes):
def __init__(self, fn, itypes, otypes, infer_shape):
self.__fn = fn
self.itypes = itypes
self.otypes = otypes
self.__infer_shape = infer_shape
def __eq__(self, other):
return (type(self) == type(other) and
......@@ -403,31 +405,55 @@ class FromFunctionOp(gof.Op):
for i in range(len(outs)):
outputs[i][0] = outs[i]
def as_op(itypes, otypes):
def infer_shape(self, node, input_shapes):
if self.__infer_shape:
return self.__infer_shape(node, input_shapes)
else:
# fake method not defined
raise AttributeError('infer_shape')
def as_op(itypes, otypes, infer_shape=None):
"""
Decorator that converts a function into a basic theano op that
will call the supplied function as its implementation.
It takes an optional infer_shape parameter that should be a
callable with this signature:
def infer_shape(node, input_shapes):
...
return output_shapes
Here `input_shapes` and `output_shapes` are lists of tuples that
represent the shape of the corresponding inputs/outputs.
This should not be used when performance is a concern since the
very basic nature of the resulting Op may interfere with certain
graph optimizations.
Example usage:
@as_op(itypes=[theano.tensor.fmatrix(), theano.tensor.fmatrix()],
otypes=[theano.tensor.fmatrix()])
def numpy_dot(a, b):
return numpy.dot(a, b)
"""
if (not isinstance(itypes, (list, tuple)) or
any(not isinstance(t, theano.Type) for t in itypes)):
if not isinstance(itypes, (list, tuple)):
itypes = [itypes]
if any(not isinstance(t, theano.Type) for t in itypes):
raise TypeError("itypes has to be a list of theano types")
if (not isinstance(otypes, (list, tuple)) or
any(not isinstance(t, theano.Type) for t in otypes)):
if not isinstance(otypes, (list, tuple)):
otypes = [otypes]
if any(not isinstance(t, theano.Type) for t in otypes)):
raise TypeError("otypes has to be a list of theano types")
# make sure they are lists and not tuples
itypes = list(itypes)
otypes = list(otypes)
if infer_shape is not None and not callable(infer_shape):
raise TypeError("infer_shape needs to be a callable")
def make_op(fn):
return FromFunctionOp(fn, itypes, otypes)
return FromFunctionOp(fn, itypes, otypes, infer_shape)
return make_op
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论