提交 2fcafab8 authored 作者: Frederic's avatar Frederic

Fix a crash as unpickle break our singleton pattern

上级 4d8536b5
...@@ -417,6 +417,19 @@ class SingletonType(Type): ...@@ -417,6 +417,19 @@ class SingletonType(Type):
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
# even if we try to make a singleton, this do not always work. So
# we compare the type. See test_type_other.test_none_Constant for
# an exmple. So we need to implement __eq__ and __hash__
def __eq__(self, other):
if self is other:
return True
if type(self) is type(other):
return True
return False
def __hash__(self):
return hash(type(self))
class Generic(SingletonType): class Generic(SingletonType):
""" """
......
""" This file don't test everything. It only test one past crash error.""" """ This file don't test everything. It only test one past crash error."""
import theano import theano
from theano.tensor.type_other import MakeSlice, make_slice from theano.gof import Constant
from theano.tensor.type_other import MakeSlice, make_slice, NoneTypeT, NoneConst
def test_make_slice_merge(): def test_make_slice_merge():
...@@ -11,4 +12,28 @@ def test_make_slice_merge(): ...@@ -11,4 +12,28 @@ def test_make_slice_merge():
f = theano.function([i], [s1, s2]) f = theano.function([i], [s1, s2])
nodes = f.maker.fgraph.nodes nodes = f.maker.fgraph.nodes
assert len([n for n in nodes if isinstance(n.op, MakeSlice)]) == 1 assert len([n for n in nodes if isinstance(n.op, MakeSlice)]) == 1
theano.printing.debugprint(f) theano.printing.debugprint(f)
\ No newline at end of file
def test_none_Constant():
""" Tests equals
We had an error in the past with unpickling
"""
o1 = Constant(NoneTypeT(), None, name='NoneConst')
o2 = Constant(NoneTypeT(), None, name='NoneConst')
assert o1.equals(o2)
assert NoneConst.equals(o1)
assert o1.equals(NoneConst)
assert NoneConst.equals(o2)
assert o2.equals(NoneConst)
# This trigger equals that returned the wrong answer in the past.
import cPickle
import theano
from theano import tensor
x = tensor.vector('x')
y = tensor.argmax(x)
f = theano.function([x], [y])
cPickle.loads(cPickle.dumps(f))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论