提交 3e8b9eee authored 作者: abergeron's avatar abergeron

Merge pull request #2517 from nouiz/crash_unpickle

Crash unpickle
......@@ -416,6 +416,19 @@ class SingletonType(Type):
def __str__(self):
return self.__class__.__name__
# even if we try to make a singleton, this do not always work. So
# we compare the type. See test_type_other.test_none_Constant for
# an exmple. So we need to implement __eq__ and __hash__
def __eq__(self, other):
if self is other:
return True
if type(self) is type(other):
return True
return False
def __hash__(self):
return hash(type(self))
class Generic(SingletonType):
"""
......
""" This file don't test everything. It only test one past crash error."""
import theano
from theano.tensor.type_other import MakeSlice, make_slice
from theano.gof import Constant
from theano.tensor.type_other import MakeSlice, make_slice, NoneTypeT, NoneConst
def test_make_slice_merge():
......@@ -11,4 +12,28 @@ def test_make_slice_merge():
f = theano.function([i], [s1, s2])
nodes = f.maker.fgraph.nodes
assert len([n for n in nodes if isinstance(n.op, MakeSlice)]) == 1
theano.printing.debugprint(f)
\ No newline at end of file
theano.printing.debugprint(f)
def test_none_Constant():
""" Tests equals
We had an error in the past with unpickling
"""
o1 = Constant(NoneTypeT(), None, name='NoneConst')
o2 = Constant(NoneTypeT(), None, name='NoneConst')
assert o1.equals(o2)
assert NoneConst.equals(o1)
assert o1.equals(NoneConst)
assert NoneConst.equals(o2)
assert o2.equals(NoneConst)
# This trigger equals that returned the wrong answer in the past.
import cPickle
import theano
from theano import tensor
x = tensor.vector('x')
y = tensor.argmax(x)
f = theano.function([x], [y])
cPickle.loads(cPickle.dumps(f))
......@@ -129,5 +129,5 @@ none_type_t = NoneTypeT()
# This is a variable instance. It can be used only once per fgraph.
# So use NoneConst.clone() before using it in a Theano graph.
# Use NoneConst.equal(x) to check if two variable are NoneConst.
NoneConst = Constant(NoneTypeT(), None, name='NoneConst')
# Use NoneConst.equals(x) to check if two variable are NoneConst.
NoneConst = Constant(none_type_t, None, name='NoneConst')
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论