提交 3e8b9eee authored 作者: abergeron's avatar abergeron

Merge pull request #2517 from nouiz/crash_unpickle

Crash unpickle
...@@ -416,6 +416,19 @@ class SingletonType(Type): ...@@ -416,6 +416,19 @@ class SingletonType(Type):
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
# even if we try to make a singleton, this do not always work. So
# we compare the type. See test_type_other.test_none_Constant for
# an exmple. So we need to implement __eq__ and __hash__
def __eq__(self, other):
if self is other:
return True
if type(self) is type(other):
return True
return False
def __hash__(self):
return hash(type(self))
class Generic(SingletonType): class Generic(SingletonType):
""" """
......
""" This file don't test everything. It only test one past crash error.""" """ This file don't test everything. It only test one past crash error."""
import theano import theano
from theano.tensor.type_other import MakeSlice, make_slice from theano.gof import Constant
from theano.tensor.type_other import MakeSlice, make_slice, NoneTypeT, NoneConst
def test_make_slice_merge(): def test_make_slice_merge():
...@@ -11,4 +12,28 @@ def test_make_slice_merge(): ...@@ -11,4 +12,28 @@ def test_make_slice_merge():
f = theano.function([i], [s1, s2]) f = theano.function([i], [s1, s2])
nodes = f.maker.fgraph.nodes nodes = f.maker.fgraph.nodes
assert len([n for n in nodes if isinstance(n.op, MakeSlice)]) == 1 assert len([n for n in nodes if isinstance(n.op, MakeSlice)]) == 1
theano.printing.debugprint(f) theano.printing.debugprint(f)
\ No newline at end of file
def test_none_Constant():
""" Tests equals
We had an error in the past with unpickling
"""
o1 = Constant(NoneTypeT(), None, name='NoneConst')
o2 = Constant(NoneTypeT(), None, name='NoneConst')
assert o1.equals(o2)
assert NoneConst.equals(o1)
assert o1.equals(NoneConst)
assert NoneConst.equals(o2)
assert o2.equals(NoneConst)
# This trigger equals that returned the wrong answer in the past.
import cPickle
import theano
from theano import tensor
x = tensor.vector('x')
y = tensor.argmax(x)
f = theano.function([x], [y])
cPickle.loads(cPickle.dumps(f))
...@@ -129,5 +129,5 @@ none_type_t = NoneTypeT() ...@@ -129,5 +129,5 @@ none_type_t = NoneTypeT()
# This is a variable instance. It can be used only once per fgraph. # This is a variable instance. It can be used only once per fgraph.
# So use NoneConst.clone() before using it in a Theano graph. # So use NoneConst.clone() before using it in a Theano graph.
# Use NoneConst.equal(x) to check if two variable are NoneConst. # Use NoneConst.equals(x) to check if two variable are NoneConst.
NoneConst = Constant(NoneTypeT(), None, name='NoneConst') NoneConst = Constant(none_type_t, None, name='NoneConst')
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论