提交 6c288a2a authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #2021 from abergeron/pickle_as_op

Make as_op instances pickleable.
...@@ -402,6 +402,12 @@ def register_shape_i_c_code(typ, code, check_input, version=()): ...@@ -402,6 +402,12 @@ def register_shape_i_c_code(typ, code, check_input, version=()):
# Scan can deal with. # Scan can deal with.
expandable_types = () expandable_types = ()
def load_back(mod, name):
__import__(mod)
import sys
module = sys.modules[mod]
obj = getattr(module, name)
return obj
class FromFunctionOp(gof.Op): class FromFunctionOp(gof.Op):
""" """
...@@ -447,6 +453,20 @@ class FromFunctionOp(gof.Op): ...@@ -447,6 +453,20 @@ class FromFunctionOp(gof.Op):
for i in range(len(outs)): for i in range(len(outs)):
outputs[i][0] = outs[i] outputs[i][0] = outs[i]
def __reduce__(self):
mod = self.__fn.__module__
name = self.__fn.__name__
try:
obj = load_back(mod, name)
except (ImportError, KeyError, AttributeError):
raise PicklingError("Can't pickle as_op(), not found as %s.%s" %
(mod, name))
else:
if obj is not self:
raise PicklingError("Can't pickle as_op(), not the object "
"at %s.%s" % (mod, name))
return load_back, (mod, name)
def _infer_shape(self, node, input_shapes): def _infer_shape(self, node, input_shapes):
return self.__infer_shape(node, input_shapes) return self.__infer_shape(node, input_shapes)
......
...@@ -10,8 +10,15 @@ from theano import tensor ...@@ -10,8 +10,15 @@ from theano import tensor
from theano.tensor import dmatrix, dvector from theano.tensor import dmatrix, dvector
from numpy import allclose from numpy import allclose
from theano.compile import as_op from theano.compile import as_op
import pickle
# This is for test_pickle, since the function still has to be
# reachable from pickle (as in it cannot be defined inline)
@as_op([dmatrix, dmatrix], dmatrix)
def mul(a, b):
return a*b
class OpDecoratorTests(utt.InferShapeTester): class OpDecoratorTests(utt.InferShapeTester):
def test_1arg(self): def test_1arg(self):
x = dmatrix('x') x = dmatrix('x')
...@@ -59,3 +66,14 @@ class OpDecoratorTests(utt.InferShapeTester): ...@@ -59,3 +66,14 @@ class OpDecoratorTests(utt.InferShapeTester):
self._compile_and_check([x, y], [diag_mult(x, y)], self._compile_and_check([x, y], [diag_mult(x, y)],
[[[1.5, 5], [2, 2]], [1, 100]], [[[1.5, 5], [2, 2]], [1, 100]],
diag_mult.__class__, warn=False) diag_mult.__class__, warn=False)
def test_pickle(self):
x = dmatrix('x')
y = dmatrix('y')
m = mul(x, y)
s = pickle.dumps(m)
m2 = pickle.loads(s)
assert m2.owner.op == m.owner.op
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论