提交 1b213675 authored 作者: Iban Harlouchet's avatar Iban Harlouchet

Ammendements provisoires de NdGrid et TestNdGrid

上级 09773611
......@@ -4639,17 +4639,21 @@ class _nd_grid(object):
def __getitem__(self, *args):
ndim = len(args[0])
for sl in args[0]:
if isinstance(sl.step, python_complex):
raise NotImplementedError("Not implemented for slices "
"whose step is complex")
ranges = [arange(sl.start or 0,
sl.stop or None,
sl.step or 1) for sl in args[0]]
shapes = [tuple([1] * j + [r.shape[0]] + [1] * (ndim - 1 - j))
for j, r in enumerate(ranges)]
ranges = [r.reshape(shape) for r, shape in zip(ranges, shapes)]
ones = [ones_like(r) for r in ranges]
if self.sparse:
grids = ranges
else:
grids = []
ones = [ones_like(r) for r in ranges]
for i in range(ndim):
grid = 1
for j in range(ndim):
......
......@@ -5485,31 +5485,43 @@ class TestNdGrid(unittest.TestCase):
def setUp(self):
pass
def test_mgrid_numpy_equiv_float(self):
def test_mgrid_numpy_equiv(self):
nmgrid = (numpy.mgrid[0:1:.1, 1:10:1., 10:100:10.],
numpy.mgrid[0:2:1, 1:10:1, 10:100:10])
tmgrid = (mgrid[0:1:.1, 1:10:1., 10:100:10.],
mgrid[0:2:1, 1:10:1, 10:100:10])
for n, g in zip(nmgrid, tmgrid):
for ng, tg in zip(n, g):
assert_array_equal(ng, tg.eval())
def test_ogrid_numpy_equiv(self):
nogrid = (numpy.ogrid[0:1:.1, 1:10:1., 10:100:10.],
numpy.ogrid[0:2:1, 1:10:1, 10:100:10])
togrid = (ogrid[0:1:.1, 1:10:1., 10:100:10.],
ogrid[0:2:1, 1:10:1, 10:100:10])
for n, g in zip(nogrid, togrid):
for ng, tg in zip(n, g):
assert_array_equal(ng, tg.eval())
def test_mgrid_theano_variable_numpy_equiv_float(self):
nfmgrid = numpy.mgrid[0:1:.1, 1:10:1., 10:100:10.]
tfmgrid = mgrid[0:1:.1, 1:10:1., 10:100:10.]
for ng, tg in zip(nfmgrid, tfmgrid):
assert_array_equal(ng, tg.eval())
def test_mgrid_numpy_equiv_int(self):
i,j,k = dscalars('i','j','k')
tfmgrid = mgrid[i:1:.1, 1:j:1., 10:100:k]
f = theano.function([i,j,k], tfmgrid)
for ng, tg in zip(nfmgrid, f(0,10,10.)):
assert_array_equal(ng, tg)
def test_mgrid_theano_variable_numpy_equiv_int(self):
nimgrid = numpy.mgrid[0:2:1, 1:10:1, 10:100:10]
timgrid = mgrid[0:2:1, 1:10:1, 10:100:10]
for ng, tg in zip(nimgrid, timgrid):
assert_array_equal(ng, tg.eval())
def test_ogrid_numpy_equiv_float(self):
nfogrid = numpy.ogrid[0:1:.1, 1:10:1., 10:100:10.]
tfogrid = ogrid[0:1:.1, 1:10:1., 10:100:10.]
for ng, tg in zip(nfogrid, tfogrid):
assert_array_equal(ng, tg.eval())
def test_ogrid_numpy_equiv_int(self):
niogrid = numpy.ogrid[0:2:1, 1:10:1, 10:100:10]
tiogrid = ogrid[0:2:1, 1:10:1, 10:100:10]
for ng, tg in zip(niogrid, tiogrid):
assert_array_equal(ng, tg.eval())
i,j,k = iscalars('i','j','k')
timgrid = mgrid[i:2:1, 1:j:1, 10:100:k]
f = theano.function([i,j,k], timgrid)
for ng, tg in zip(nimgrid, f(0,10,10)):
assert_array_equal(ng, tg)
# def test_mgrid_theano_variable_numpy_equiv(self):
# nmgrid = (numpy.mgrid[0:1:.1, 1:10:1., 10:100:10.],
# numpy.mgrid[0:2:1, 1:10:1, 10:100:10])
class TestInversePermutation(unittest.TestCase):
def setUp(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论