提交 09773611 authored 作者: Michael Eickenberg's avatar Michael Eickenberg 提交者: Iban Harlouchet

WIP added mgrid/ogrid functionality

上级 859bba60
......@@ -4595,6 +4595,76 @@ def arange(start, stop=None, step=1, dtype=None):
return _arange[dtype](start, stop, step)
class _nd_grid(object):
"""Create a dense n-dimensional 'meshgrid' with equally spaced points.
Used to create the instance ``mgrid`` and ``ogrid`` which act similarly
to their numpy equivalents.
Parameters
==========
sparse : boolean, optional, default=True
Specifying False leads to the equivalent of numpy's mgrid
functionality. Specifying True leads to the equivalent of ogrid.
Examples
========
>>> a = T.mgrid[0:5, 0:3]
>>> a[0].eval()
array([[0, 0, 0],
[1, 1, 1],
[2, 2, 2],
[3, 3, 3],
[4, 4, 4]], dtype=int8)
>>> a[1].eval()
array([[0, 1, 2],
[0, 1, 2],
[0, 1, 2],
[0, 1, 2],
[0, 1, 2]], dtype=int8)
>>> b = T.ogrid[0:5, 0:3]
>>> b[0].eval()
array([[0],
[1],
[2],
[3],
[4]], dtype=int8)
>>> b[1].eval()
array([[0, 1, 2, 3]], dtype=int8)
"""
def __init__(self, sparse=False):
self.sparse = sparse
def __getitem__(self, *args):
ndim = len(args[0])
ranges = [arange(sl.start or 0,
sl.stop or None,
sl.step or 1) for sl in args[0]]
shapes = [tuple([1] * j + [r.shape[0]] + [1] * (ndim - 1 - j))
for j, r in enumerate(ranges)]
ranges = [r.reshape(shape) for r, shape in zip(ranges, shapes)]
ones = [ones_like(r) for r in ranges]
if self.sparse:
grids = ranges
else:
grids = []
for i in range(ndim):
grid = 1
for j in range(ndim):
if j == i:
grid = grid * ranges[j]
else:
grid = grid * ones[j]
grids.append(grid)
return grids
mgrid = _nd_grid()
ogrid = _nd_grid(sparse=True)
class PermuteRowElements(Op):
"""Permute the elements of each row (inner-most dim) of a tensor.
......
......@@ -48,7 +48,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as,
nonzero, flatnonzero, nonzero_values,
stacklists, DimShuffle, hessian, ptp, power,
swapaxes, choose, Choose, NoneConst, AllocEmpty,
isclose, allclose,
isclose, allclose, mgrid, ogrid,
)
from theano.tests import unittest_tools as utt
......@@ -5480,6 +5480,37 @@ class TestARange(unittest.TestCase):
assert numpy.all(f(0) == len(numpy.arange(0, 0)))
class TestNdGrid(unittest.TestCase):
def setUp(self):
pass
def test_mgrid_numpy_equiv_float(self):
nfmgrid = numpy.mgrid[0:1:.1, 1:10:1., 10:100:10.]
tfmgrid = mgrid[0:1:.1, 1:10:1., 10:100:10.]
for ng, tg in zip(nfmgrid, tfmgrid):
assert_array_equal(ng, tg.eval())
def test_mgrid_numpy_equiv_int(self):
nimgrid = numpy.mgrid[0:2:1, 1:10:1, 10:100:10]
timgrid = mgrid[0:2:1, 1:10:1, 10:100:10]
for ng, tg in zip(nimgrid, timgrid):
assert_array_equal(ng, tg.eval())
def test_ogrid_numpy_equiv_float(self):
nfogrid = numpy.ogrid[0:1:.1, 1:10:1., 10:100:10.]
tfogrid = ogrid[0:1:.1, 1:10:1., 10:100:10.]
for ng, tg in zip(nfogrid, tfogrid):
assert_array_equal(ng, tg.eval())
def test_ogrid_numpy_equiv_int(self):
niogrid = numpy.ogrid[0:2:1, 1:10:1, 10:100:10]
tiogrid = ogrid[0:2:1, 1:10:1, 10:100:10]
for ng, tg in zip(niogrid, tiogrid):
assert_array_equal(ng, tg.eval())
class TestInversePermutation(unittest.TestCase):
def setUp(self):
utt.seed_rng()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论