提交 c7702a5e authored 作者: Frederic's avatar Frederic

Fix pickling of GpuIncSubtensor

上级 77f31aba
......@@ -181,16 +181,21 @@ class GpuIncSubtensor(IncSubtensor):
x = as_gpuarray_variable(x)
y = as_gpuarray_variable(y)
rval = tensor.IncSubtensor.make_node(self, x, y, *inputs)
op = copy.copy(self)
ret = gof.Apply(op, [x, y] + rval.inputs[2:], [x.type()])
op.create_iadd_node(ret)
return ret
def create_iadd_node(self, node):
# We store a iadd_node in the op that contain the info needed
# for the inplace add.
cop = theano.tensor.inplace.add_inplace
gop = GpuElemwise(cop.scalar_op, copy.copy(cop.inplace_pattern),
"Gpu" + cop.name, cop.nfunc_spec)
y = node.inputs[1]
xview = y.type()
iadd_node = gop(xview, y).owner
op = copy.copy(self)
op.iadd_node = iadd_node
return gof.Apply(op, [x, y] + rval.inputs[2:], [x.type()])
self.iadd_node = iadd_node
def perform(self, node, inputs, out_):
out, = out_
......@@ -232,6 +237,18 @@ class GpuIncSubtensor(IncSubtensor):
x.__setitem__(cdata, y)
out[0] = x
def __setstate__(self, d):
self.__dict__.update(d)
owner = getattr(self.__dict__, "owner", None)
if owner:
op.create_iadd_node(owner)
def __getstate__(self):
d = copy.copy(self.__dict__)
if "iadd_node" in d:
d.pop('iadd_node')
return d
def do_type_checking(self, node):
""" Should raise NotImplementedError if c_code does not support
the types involved in this node.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论