提交 c42414b6 authored 作者: James Bergstra's avatar James Bergstra

Added __setstate__ and __getstate__ to Gemm so that old pickled class instances

can be unpickled properly even though Gemm has an 'inplace' property now.
上级 db26bcf0
...@@ -301,12 +301,7 @@ class Gemm(GemmRelated): ...@@ -301,12 +301,7 @@ class Gemm(GemmRelated):
E_scalar = 'gemm requires scalar argument' E_scalar = 'gemm requires scalar argument'
E_z_uniq = 'argument z aliased to x or y' E_z_uniq = 'argument z aliased to x or y'
def __init__(self, inplace): def __init__(self, inplace):
if inplace: self.__setstate__({'inplace':inplace})
self.destroy_map = {0: [0]}
self.setup_z_Nz_Sz = self.setup_z_Nz_Sz_inplace
else:
self.setup_z_Nz_Sz = self.setup_z_Nz_Sz_outplace
self.inplace = inplace
def __eq__(self, other): def __eq__(self, other):
return (type(self) == type(other)\ return (type(self) == type(other)\
...@@ -320,6 +315,18 @@ class Gemm(GemmRelated): ...@@ -320,6 +315,18 @@ class Gemm(GemmRelated):
else: inplace_str = 'no_inplace' else: inplace_str = 'no_inplace'
return '%s{%s}' % (self.__class__.__name__, inplace_str) return '%s{%s}' % (self.__class__.__name__, inplace_str)
def __setstate__(self, dct):
inplace = dct.get('inplace', True)
if inplace:
self.destroy_map = {0: [0]}
self.setup_z_Nz_Sz = self.setup_z_Nz_Sz_inplace
else:
self.setup_z_Nz_Sz = self.setup_z_Nz_Sz_outplace
self.inplace = inplace
def __getstate__(self):
return dict(inplace=self.inplace)
def make_node(self, *inputs): def make_node(self, *inputs):
inputs = map(T.as_tensor_variable, inputs) inputs = map(T.as_tensor_variable, inputs)
if len(inputs) != 5: if len(inputs) != 5:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论