提交 0281542f authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #3550 from lamblin/fix_gemm_unpickle

Fix unpickling of Gemm and GpuGemm Ops
...@@ -401,7 +401,9 @@ class GpuGemm(GpuOp): ...@@ -401,7 +401,9 @@ class GpuGemm(GpuOp):
""" """
def __init__(self, inplace): def __init__(self, inplace):
self.__setstate__({'inplace': inplace}) self.inplace = inplace
if self.inplace:
self.destroy_map = {0: [0]}
def __str__(self): def __str__(self):
if self.inplace: if self.inplace:
...@@ -417,13 +419,14 @@ class GpuGemm(GpuOp): ...@@ -417,13 +419,14 @@ class GpuGemm(GpuOp):
return hash(type(self)) ^ hash(self.inplace) return hash(type(self)) ^ hash(self.inplace)
def __setstate__(self, dct): def __setstate__(self, dct):
inplace = dct.get('inplace', True) self.__dict__.update(dct)
if inplace:
self.destroy_map = {0: [0]}
self.inplace = inplace
def __getstate__(self): # Correctly reload older pickles where _op_use_c_code and
return dict(inplace=self.inplace) # destroy_map were not saved
if '_op_use_c_code' not in self.__dict__:
self._op_use_c_code = theano.config.cxx
if 'destroy_map' not in self.__dict__ and self.inplace:
self.destroy_map = {0: [0]}
def make_node(self, z, a, x, y, b): def make_node(self, z, a, x, y, b):
# the more complicated error checking performed by tensor.gemm # the more complicated error checking performed by tensor.gemm
...@@ -518,7 +521,9 @@ class GpuGemv(GpuOp): ...@@ -518,7 +521,9 @@ class GpuGemv(GpuOp):
""" """
def __init__(self, inplace): def __init__(self, inplace):
self.__setstate__({'inplace': inplace}) self.inplace = inplace
if self.inplace:
self.destroy_map = {0: [0]}
def __str__(self): def __str__(self):
if self.inplace: if self.inplace:
...@@ -534,13 +539,14 @@ class GpuGemv(GpuOp): ...@@ -534,13 +539,14 @@ class GpuGemv(GpuOp):
return hash(type(self)) ^ hash(self.inplace) return hash(type(self)) ^ hash(self.inplace)
def __setstate__(self, dct): def __setstate__(self, dct):
inplace = dct.get('inplace', True) self.__dict__.update(dct)
if inplace:
self.destroy_map = {0: [0]}
self.inplace = inplace
def __getstate__(self): # Correctly reload older pickles where _op_use_c_code and
return dict(inplace=self.inplace) # destroy_map were not saved
if '_op_use_c_code' not in self.__dict__:
self._op_use_c_code = theano.config.cxx
if 'destroy_map' not in self.__dict__ and self.inplace:
self.destroy_map = {0: [0]}
def make_node(self, z, a, x, y, b): def make_node(self, z, a, x, y, b):
# the more complicated error checking performed by tensor.gemv # the more complicated error checking performed by tensor.gemv
...@@ -615,7 +621,9 @@ class GpuGer(GpuOp): ...@@ -615,7 +621,9 @@ class GpuGer(GpuOp):
""" """
def __init__(self, inplace): def __init__(self, inplace):
self.__setstate__({'inplace': inplace}) self.inplace = inplace
if self.inplace:
self.destroy_map = {0: [0]}
def __str__(self): def __str__(self):
if self.inplace: if self.inplace:
...@@ -631,13 +639,14 @@ class GpuGer(GpuOp): ...@@ -631,13 +639,14 @@ class GpuGer(GpuOp):
return hash(type(self)) ^ hash(self.inplace) return hash(type(self)) ^ hash(self.inplace)
def __setstate__(self, dct): def __setstate__(self, dct):
inplace = dct.get('inplace', True) self.__dict__.update(dct)
if inplace:
self.destroy_map = {0: [0]}
self.inplace = inplace
def __getstate__(self): # Correctly reload older pickles where _op_use_c_code and
return dict(inplace=self.inplace) # destroy_map were not saved
if '_op_use_c_code' not in self.__dict__:
self._op_use_c_code = theano.config.cxx
if 'destroy_map' not in self.__dict__ and self.inplace:
self.destroy_map = {0: [0]}
def make_node(self, z, a, x, y): def make_node(self, z, a, x, y):
# the more complicated error checking performed by tensor.ger is # the more complicated error checking performed by tensor.ger is
......
...@@ -962,7 +962,12 @@ class Gemm(GemmRelated): ...@@ -962,7 +962,12 @@ class Gemm(GemmRelated):
E_float = 'gemm requires floating-point dtypes' E_float = 'gemm requires floating-point dtypes'
def __init__(self, inplace): def __init__(self, inplace):
self.__setstate__({'inplace': inplace}) self.inplace = inplace
if self.inplace:
self.destroy_map = {0: [0]}
self.setup_z_Nz_Sz = self.setup_z_Nz_Sz_inplace
else:
self.setup_z_Nz_Sz = self.setup_z_Nz_Sz_outplace
def __eq__(self, other): def __eq__(self, other):
return (type(self) == type(other) and return (type(self) == type(other) and
...@@ -979,16 +984,25 @@ class Gemm(GemmRelated): ...@@ -979,16 +984,25 @@ class Gemm(GemmRelated):
return '%s{%s}' % (self.__class__.__name__, inplace_str) return '%s{%s}' % (self.__class__.__name__, inplace_str)
def __setstate__(self, dct): def __setstate__(self, dct):
inplace = dct.get('inplace', True) self.__dict__.update(dct)
if inplace: if self.inplace:
self.destroy_map = {0: [0]}
self.setup_z_Nz_Sz = self.setup_z_Nz_Sz_inplace self.setup_z_Nz_Sz = self.setup_z_Nz_Sz_inplace
else: else:
self.setup_z_Nz_Sz = self.setup_z_Nz_Sz_outplace self.setup_z_Nz_Sz = self.setup_z_Nz_Sz_outplace
self.inplace = inplace
# Correctly reload older pickles where _op_use_c_code and
# destroy_map were not saved
if '_op_use_c_code' not in self.__dict__:
self._op_use_c_code = theano.config.cxx
if 'destroy_map' not in self.__dict__ and self.inplace:
self.destroy_map = {0: [0]}
def __getstate__(self): def __getstate__(self):
return dict(inplace=self.inplace) rval = self.__dict__.copy()
# Do not serialize the setup code, it will be restored in __setstate__
# depending on the value of 'inplace'
rval.pop('setup_z_Nz_Sz')
return rval
def make_node(self, *inputs): def make_node(self, *inputs):
inputs = list(map(T.as_tensor_variable, inputs)) inputs = list(map(T.as_tensor_variable, inputs))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论