提交 4bb2f3f6 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Merge pull request #206 from nouiz/fix_test

Fix test
......@@ -161,26 +161,29 @@ def test_rop_lop():
def test_det():
rng = numpy.random.RandomState(utt.fetch_seed())
r = rng.randn(5,5)
r = rng.randn(5, 5).astype(config.floatX)
x = tensor.matrix()
f = theano.function([x],det(x))
assert numpy.linalg.det(r) == f(r)
f = theano.function([x], det(x))
assert numpy.allclose(numpy.linalg.det(r), f(r))
def test_det_grad():
rng = numpy.random.RandomState(utt.fetch_seed())
r = rng.randn(5,5)
r = rng.randn(5, 5).astype(config.floatX)
tensor.verify_grad(det, [r], rng=numpy.random)
def test_det_shape():
rng = numpy.random.RandomState(utt.fetch_seed())
r = rng.randn(5,5)
r = rng.randn(5, 5).astype(config.floatX)
x = tensor.matrix()
f = theano.function([x],det(x))
f_shape = theano.function([x],det(x).shape)
f = theano.function([x], det(x))
f_shape = theano.function([x], det(x).shape)
assert numpy.all(f(r).shape == f_shape(r))
def test_extract_diag():
rng = numpy.random.RandomState(utt.fetch_seed())
x = theano.tensor.matrix()
......
......@@ -1124,6 +1124,12 @@ class TimesN(theano.scalar.basic.UnaryScalarOp):
Must be outside of the class, otherwise, the c cache code can't
pickle this class and this cause stuff printing during test.
"""
def __eq__(self, other):
return super(TimesN, self).__eq__(other) and self.n == other.n
def __hash__(self):
return super(TimesN, self).__hash__() ^ hash(self.n)
def __init__(self, n, *args, **kwargs):
self.n = n
theano.scalar.basic.UnaryScalarOp.__init__(self, *args, **kwargs)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论