提交 3c237906 authored 作者: amrithasuresh's avatar amrithasuresh

1. Updated numpy as np

2. Fixed indentation
上级 1807d925
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
import unittest import unittest
import numpy import numpy as np
import theano import theano
from theano import function, config from theano import function, config
...@@ -28,7 +28,7 @@ class T_max_and_argmax(unittest.TestCase): ...@@ -28,7 +28,7 @@ class T_max_and_argmax(unittest.TestCase):
'canonicalize', 'fast_run') 'canonicalize', 'fast_run')
for axis in [0, 1, -1]: for axis in [0, 1, -1]:
data = numpy.asarray(numpy.random.rand(2, 3), dtype=config.floatX) data = np.asarray(np.random.rand(2, 3), dtype=config.floatX)
n = tensor.matrix() n = tensor.matrix()
f = function([n], tensor.max_and_argmax(n, axis)[0], mode=mode) f = function([n], tensor.max_and_argmax(n, axis)[0], mode=mode)
...@@ -49,7 +49,7 @@ class T_min_max(unittest.TestCase): ...@@ -49,7 +49,7 @@ class T_min_max(unittest.TestCase):
'canonicalize', 'fast_run') 'canonicalize', 'fast_run')
def test_optimization_max(self): def test_optimization_max(self):
data = numpy.asarray(numpy.random.rand(2, 3), dtype=config.floatX) data = np.asarray(np.random.rand(2, 3), dtype=config.floatX)
n = tensor.matrix() n = tensor.matrix()
for axis in [0, 1, -1]: for axis in [0, 1, -1]:
...@@ -82,7 +82,7 @@ class T_min_max(unittest.TestCase): ...@@ -82,7 +82,7 @@ class T_min_max(unittest.TestCase):
f(data) f(data)
def test_optimization_min(self): def test_optimization_min(self):
data = numpy.asarray(numpy.random.rand(2, 3), dtype=config.floatX) data = np.asarray(np.random.rand(2, 3), dtype=config.floatX)
n = tensor.matrix() n = tensor.matrix()
for axis in [0, 1, -1]: for axis in [0, 1, -1]:
...@@ -206,4 +206,4 @@ def test_local_dimshuffle_subtensor(): ...@@ -206,4 +206,4 @@ def test_local_dimshuffle_subtensor():
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
assert any([not isinstance(x, DimShuffle) for x in topo]) assert any([not isinstance(x, DimShuffle) for x in topo])
assert f(numpy.random.rand(5, 1, 4, 1), 2).shape == (4,) assert f(np.random.rand(5, 1, 4, 1), 2).shape == (4,)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论