提交 5344866f authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Make it possible to reload pickled Gemm

上级 146fbae8
......@@ -401,7 +401,9 @@ class GpuGemm(GpuOp):
"""
def __init__(self, inplace):
self.__setstate__({'inplace': inplace})
self.inplace = inplace
if self.inplace:
self.destroy_map = {0: [0]}
def __str__(self):
if self.inplace:
......@@ -417,13 +419,14 @@ class GpuGemm(GpuOp):
return hash(type(self)) ^ hash(self.inplace)
def __setstate__(self, dct):
inplace = dct.get('inplace', True)
if inplace:
self.destroy_map = {0: [0]}
self.inplace = inplace
self.__dict__.update(dct)
def __getstate__(self):
return dict(inplace=self.inplace)
# Correctly reload older pickles where _op_use_c_code and
# destroy_map were not saved
if '_op_use_c_code' not in self.__dict__:
self._op_use_c_code = theano.config.cxx
if 'destroy_map' not in self.__dict__ and self.inplace:
self.destroy_map = {0: [0]}
def make_node(self, z, a, x, y, b):
# the more complicated error checking performed by tensor.gemm
......@@ -518,7 +521,9 @@ class GpuGemv(GpuOp):
"""
def __init__(self, inplace):
self.__setstate__({'inplace': inplace})
self.inplace = inplace
if self.inplace:
self.destroy_map = {0: [0]}
def __str__(self):
if self.inplace:
......@@ -534,13 +539,14 @@ class GpuGemv(GpuOp):
return hash(type(self)) ^ hash(self.inplace)
def __setstate__(self, dct):
inplace = dct.get('inplace', True)
if inplace:
self.destroy_map = {0: [0]}
self.inplace = inplace
self.__dict__.update(dct)
def __getstate__(self):
return dict(inplace=self.inplace)
# Correctly reload older pickles where _op_use_c_code and
# destroy_map were not saved
if '_op_use_c_code' not in self.__dict__:
self._op_use_c_code = theano.config.cxx
if 'destroy_map' not in self.__dict__ and self.inplace:
self.destroy_map = {0: [0]}
def make_node(self, z, a, x, y, b):
# the more complicated error checking performed by tensor.gemv
......@@ -615,7 +621,9 @@ class GpuGer(GpuOp):
"""
def __init__(self, inplace):
self.__setstate__({'inplace': inplace})
self.inplace = inplace
if self.inplace:
self.destroy_map = {0: [0]}
def __str__(self):
if self.inplace:
......@@ -631,13 +639,14 @@ class GpuGer(GpuOp):
return hash(type(self)) ^ hash(self.inplace)
def __setstate__(self, dct):
inplace = dct.get('inplace', True)
if inplace:
self.destroy_map = {0: [0]}
self.inplace = inplace
self.__dict__.update(dct)
def __getstate__(self):
return dict(inplace=self.inplace)
# Correctly reload older pickles where _op_use_c_code and
# destroy_map were not saved
if '_op_use_c_code' not in self.__dict__:
self._op_use_c_code = theano.config.cxx
if 'destroy_map' not in self.__dict__ and self.inplace:
self.destroy_map = {0: [0]}
def make_node(self, z, a, x, y):
# the more complicated error checking performed by tensor.ger is
......
......@@ -962,7 +962,12 @@ class Gemm(GemmRelated):
E_float = 'gemm requires floating-point dtypes'
def __init__(self, inplace):
self.__setstate__({'inplace': inplace})
self.inplace = inplace
if self.inplace:
self.destroy_map = {0: [0]}
self.setup_z_Nz_Sz = self.setup_z_Nz_Sz_inplace
else:
self.setup_z_Nz_Sz = self.setup_z_Nz_Sz_outplace
def __eq__(self, other):
return (type(self) == type(other) and
......@@ -979,16 +984,25 @@ class Gemm(GemmRelated):
return '%s{%s}' % (self.__class__.__name__, inplace_str)
def __setstate__(self, dct):
inplace = dct.get('inplace', True)
if inplace:
self.destroy_map = {0: [0]}
self.__dict__.update(dct)
if self.inplace:
self.setup_z_Nz_Sz = self.setup_z_Nz_Sz_inplace
else:
self.setup_z_Nz_Sz = self.setup_z_Nz_Sz_outplace
self.inplace = inplace
# Correctly reload older pickles where _op_use_c_code and
# destroy_map were not saved
if '_op_use_c_code' not in self.__dict__:
self._op_use_c_code = theano.config.cxx
if 'destroy_map' not in self.__dict__ and self.inplace:
self.destroy_map = {0: [0]}
def __getstate__(self):
return dict(inplace=self.inplace)
rval = self.__dict__.copy()
# Do not serialize the setup code, it will be restored in __setstate__
# depending on the value of 'inplace'
rval.pop('setup_z_Nz_Sz')
return rval
def make_node(self, *inputs):
inputs = list(map(T.as_tensor_variable, inputs))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论