提交 17521fd3 authored 作者: amrithasuresh's avatar amrithasuresh

Updated numpy as np

上级 1f5087d2
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
import numpy import numpy as np
import unittest import unittest
import theano import theano
...@@ -39,14 +39,14 @@ class TestScanCheckpoint(unittest.TestCase): ...@@ -39,14 +39,14 @@ class TestScanCheckpoint(unittest.TestCase):
f = theano.function(inputs=[self.A, self.k], f = theano.function(inputs=[self.A, self.k],
outputs=[self.result, self.result_check]) outputs=[self.result, self.result_check])
out, out_check = f(range(10), 101) out, out_check = f(range(10), 101)
assert numpy.allclose(out, out_check) assert np.allclose(out, out_check)
def test_backward_pass(self): def test_backward_pass(self):
"""Test gradient computation of A**k.""" """Test gradient computation of A**k."""
f = theano.function(inputs=[self.A, self.k], f = theano.function(inputs=[self.A, self.k],
outputs=[self.grad_A, self.grad_A_check]) outputs=[self.grad_A, self.grad_A_check])
out, out_check = f(range(10), 101) out, out_check = f(range(10), 101)
assert numpy.allclose(out, out_check) assert np.allclose(out, out_check)
@unittest.skipUnless(PYGPU_AVAILABLE, 'Requires pygpu.') @unittest.skipUnless(PYGPU_AVAILABLE, 'Requires pygpu.')
def test_memory(self): def test_memory(self):
...@@ -59,7 +59,7 @@ class TestScanCheckpoint(unittest.TestCase): ...@@ -59,7 +59,7 @@ class TestScanCheckpoint(unittest.TestCase):
f_check = theano.function(inputs=[self.A, self.k], f_check = theano.function(inputs=[self.A, self.k],
outputs=self.grad_A_check, mode=mode_with_gpu) outputs=self.grad_A_check, mode=mode_with_gpu)
free_gmem = theano.gpuarray.type._context_reg[None].free_gmem free_gmem = theano.gpuarray.type._context_reg[None].free_gmem
data = numpy.ones(free_gmem // 3000, dtype=numpy.float32) data = np.ones(free_gmem // 3000, dtype=np.float32)
# Check that it works with the checkpoints # Check that it works with the checkpoints
f_check(data, 1000) f_check(data, 1000)
# Check that the basic scan fails in that case # Check that the basic scan fails in that case
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论