提交 1b213675 authored 作者: Iban Harlouchet's avatar Iban Harlouchet

Ammendements provisoires de NdGrid et TestNdGrid

上级 09773611
...@@ -4639,17 +4639,21 @@ class _nd_grid(object): ...@@ -4639,17 +4639,21 @@ class _nd_grid(object):
def __getitem__(self, *args): def __getitem__(self, *args):
ndim = len(args[0]) ndim = len(args[0])
for sl in args[0]:
if isinstance(sl.step, python_complex):
raise NotImplementedError("Not implemented for slices "
"whose step is complex")
ranges = [arange(sl.start or 0, ranges = [arange(sl.start or 0,
sl.stop or None, sl.stop or None,
sl.step or 1) for sl in args[0]] sl.step or 1) for sl in args[0]]
shapes = [tuple([1] * j + [r.shape[0]] + [1] * (ndim - 1 - j)) shapes = [tuple([1] * j + [r.shape[0]] + [1] * (ndim - 1 - j))
for j, r in enumerate(ranges)] for j, r in enumerate(ranges)]
ranges = [r.reshape(shape) for r, shape in zip(ranges, shapes)] ranges = [r.reshape(shape) for r, shape in zip(ranges, shapes)]
ones = [ones_like(r) for r in ranges]
if self.sparse: if self.sparse:
grids = ranges grids = ranges
else: else:
grids = [] grids = []
ones = [ones_like(r) for r in ranges]
for i in range(ndim): for i in range(ndim):
grid = 1 grid = 1
for j in range(ndim): for j in range(ndim):
......
...@@ -5485,31 +5485,43 @@ class TestNdGrid(unittest.TestCase): ...@@ -5485,31 +5485,43 @@ class TestNdGrid(unittest.TestCase):
def setUp(self): def setUp(self):
pass pass
def test_mgrid_numpy_equiv_float(self): def test_mgrid_numpy_equiv(self):
nmgrid = (numpy.mgrid[0:1:.1, 1:10:1., 10:100:10.],
numpy.mgrid[0:2:1, 1:10:1, 10:100:10])
tmgrid = (mgrid[0:1:.1, 1:10:1., 10:100:10.],
mgrid[0:2:1, 1:10:1, 10:100:10])
for n, g in zip(nmgrid, tmgrid):
for ng, tg in zip(n, g):
assert_array_equal(ng, tg.eval())
def test_ogrid_numpy_equiv(self):
nogrid = (numpy.ogrid[0:1:.1, 1:10:1., 10:100:10.],
numpy.ogrid[0:2:1, 1:10:1, 10:100:10])
togrid = (ogrid[0:1:.1, 1:10:1., 10:100:10.],
ogrid[0:2:1, 1:10:1, 10:100:10])
for n, g in zip(nogrid, togrid):
for ng, tg in zip(n, g):
assert_array_equal(ng, tg.eval())
def test_mgrid_theano_variable_numpy_equiv_float(self):
nfmgrid = numpy.mgrid[0:1:.1, 1:10:1., 10:100:10.] nfmgrid = numpy.mgrid[0:1:.1, 1:10:1., 10:100:10.]
tfmgrid = mgrid[0:1:.1, 1:10:1., 10:100:10.] i,j,k = dscalars('i','j','k')
for ng, tg in zip(nfmgrid, tfmgrid): tfmgrid = mgrid[i:1:.1, 1:j:1., 10:100:k]
assert_array_equal(ng, tg.eval()) f = theano.function([i,j,k], tfmgrid)
for ng, tg in zip(nfmgrid, f(0,10,10.)):
def test_mgrid_numpy_equiv_int(self): assert_array_equal(ng, tg)
def test_mgrid_theano_variable_numpy_equiv_int(self):
nimgrid = numpy.mgrid[0:2:1, 1:10:1, 10:100:10] nimgrid = numpy.mgrid[0:2:1, 1:10:1, 10:100:10]
timgrid = mgrid[0:2:1, 1:10:1, 10:100:10] i,j,k = iscalars('i','j','k')
for ng, tg in zip(nimgrid, timgrid): timgrid = mgrid[i:2:1, 1:j:1, 10:100:k]
assert_array_equal(ng, tg.eval()) f = theano.function([i,j,k], timgrid)
for ng, tg in zip(nimgrid, f(0,10,10)):
def test_ogrid_numpy_equiv_float(self): assert_array_equal(ng, tg)
nfogrid = numpy.ogrid[0:1:.1, 1:10:1., 10:100:10.]
tfogrid = ogrid[0:1:.1, 1:10:1., 10:100:10.] # def test_mgrid_theano_variable_numpy_equiv(self):
for ng, tg in zip(nfogrid, tfogrid): # nmgrid = (numpy.mgrid[0:1:.1, 1:10:1., 10:100:10.],
assert_array_equal(ng, tg.eval()) # numpy.mgrid[0:2:1, 1:10:1, 10:100:10])
def test_ogrid_numpy_equiv_int(self):
niogrid = numpy.ogrid[0:2:1, 1:10:1, 10:100:10]
tiogrid = ogrid[0:2:1, 1:10:1, 10:100:10]
for ng, tg in zip(niogrid, tiogrid):
assert_array_equal(ng, tg.eval())
class TestInversePermutation(unittest.TestCase): class TestInversePermutation(unittest.TestCase):
def setUp(self): def setUp(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论