提交 c5b999ed authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add an op that wrap a python function and uses it as its perform.

上级 19c0dbcb
...@@ -364,3 +364,70 @@ def register_shape_i_c_code(typ, code, version=()): ...@@ -364,3 +364,70 @@ def register_shape_i_c_code(typ, code, version=()):
# List of Theano Types that one can add an extra dimension and for which # List of Theano Types that one can add an extra dimension and for which
# Scan can deal with. # Scan can deal with.
expandable_types = () expandable_types = ()
class FromFunctionOp(gof.Op):
"""
Build a basic theano Op around a function.
Since the resulting Op is very basic and is missing most of the
optional functionality, optimizations that rely on shape
information will have trouble with this.
Also the gradient is undefined in the resulting op and theano will
raise an error if you attempt to get the gradient of a graph
containing this op.
"""
def __init__(self, fn, itypes, otypes):
self.__fn = fn
self.itypes = itypes
self.otypes = otypes
def __eq__(self, other):
return (type(self) == type(other) and
self.__fn == other.__fn)
def __hash__(self):
return hash(type(self)) ^ hash(self.__fn)
def __str__(self):
return 'FromFunctionOp{%s}' % self.__fn.__name__
def make_node(self, *inputs):
return theano.Apply(self, self.itypes, self.otypes)
def perform(self, node, inputs, outputs):
outs = self.__fn(*inputs)
if not isinstance(outs, (list, tuple)):
outs = (outs,)
assert len(outs) == len(outputs)
for i in range(len(outs)):
outputs[i][0] = outs[i]
def as_op(itypes, otypes):
"""
Decorator that converts a function into a basic theano op that
will call the supplied function as its implementation.
This should not be used when performance is a concern since the
very basic nature of the resulting Op may interfere with certain
graph optimizations.
Example usage:
@as_op(itypes=[theano.tensor.fmatrix(), theano.tensor.fmatrix()],
otypes=[theano.tensor.fmatrix()])
def numpy_dot(a, b):
return numpy.dot(a, b)
"""
if (not isinstance(itypes, (list, tuple)) or
any(not isinstance(t, theano.Type) for t in itypes)):
raise TypeError("itypes has to be a list of theano types")
if (not isinstance(otypes, (list, tuple)) or
any(not isinstance(t, theano.Type) for t in otypes)):
raise TypeError("otypes has to be a list of theano types")
itypes = list(itypes)
otypes = list(otypes)
def make_op(fn):
return FromFunctionOp(fn, itypes, otypes)
return make_op
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论