提交 3e84371d authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #3071 from harlouci/eickenberg-mgrid

Finished mgrid and ogrid
...@@ -1643,8 +1643,49 @@ Linear Algebra ...@@ -1643,8 +1643,49 @@ Linear Algebra
:note: See :func:`tensordot` and :func:`batched_dot` for :note: See :func:`tensordot` and :func:`batched_dot` for
supplementary documentation. supplementary documentation.
.. function:: mgrid
:returns: an instance which returns a dense (or fleshed out) mesh-grid
when indexed, so that each returned argument has the same shape.
The dimensions and number of the output arrays are equal to the
number of indexing dimensions. If the step length is not a complex
number, then the stop is not inclusive.
Example:
>>> a = T.mgrid[0:5, 0:3]
>>> a[0].eval()
array([[0, 0, 0],
[1, 1, 1],
[2, 2, 2],
[3, 3, 3],
[4, 4, 4]], dtype=int8)
>>> a[1].eval()
array([[0, 1, 2],
[0, 1, 2],
[0, 1, 2],
[0, 1, 2],
[0, 1, 2]], dtype=int8)
.. function:: ogrid
:returns: an instance which returns an open (i.e. not fleshed out) mesh-grid
when indexed, so that only one dimension of each returned array is
greater than 1. The dimension and number of the output arrays are
equal to the number of indexing dimensions. If the step length is
not a complex number, then the stop is not inclusive.
Example:
>>> b = T.ogrid[0:5, 0:3]
>>> b[0].eval()
array([[0],
[1],
[2],
[3],
[4]], dtype=int8)
>>> b[1].eval()
array([[0, 1, 2, 3]], dtype=int8)
Gradient / Differentiation Gradient / Differentiation
========================== ==========================
......
...@@ -4585,6 +4585,80 @@ def arange(start, stop=None, step=1, dtype=None): ...@@ -4585,6 +4585,80 @@ def arange(start, stop=None, step=1, dtype=None):
return _arange[dtype](start, stop, step) return _arange[dtype](start, stop, step)
class _nd_grid(object):
"""Create a dense n-dimensional 'meshgrid' with equally spaced points.
Used to create the instance ``mgrid`` and ``ogrid`` which act similarly
to their numpy equivalents.
Parameters
==========
sparse : boolean, optional, default=True
Specifying False leads to the equivalent of numpy's mgrid
functionality. Specifying True leads to the equivalent of ogrid.
Examples
========
>>> a = T.mgrid[0:5, 0:3]
>>> a[0].eval()
array([[0, 0, 0],
[1, 1, 1],
[2, 2, 2],
[3, 3, 3],
[4, 4, 4]], dtype=int8)
>>> a[1].eval()
array([[0, 1, 2],
[0, 1, 2],
[0, 1, 2],
[0, 1, 2],
[0, 1, 2]], dtype=int8)
>>> b = T.ogrid[0:5, 0:3]
>>> b[0].eval()
array([[0],
[1],
[2],
[3],
[4]], dtype=int8)
>>> b[1].eval()
array([[0, 1, 2, 3]], dtype=int8)
"""
def __init__(self, sparse=False):
self.sparse = sparse
def __getitem__(self, *args):
ndim = len(args[0])
for sl in args[0]:
if isinstance(sl.step, python_complex):
raise NotImplementedError("Not implemented for slices "
"whose step is complex")
ranges = [arange(sl.start or 0,
sl.stop,
sl.step or 1) for sl in args[0]]
shapes = [tuple([1] * j + [r.shape[0]] + [1] * (ndim - 1 - j))
for j, r in enumerate(ranges)]
ranges = [r.reshape(shape) for r, shape in zip(ranges, shapes)]
if self.sparse:
grids = ranges
else:
grids = []
ones = [ones_like(r) for r in ranges]
for i in range(ndim):
grid = 1
for j in range(ndim):
if j == i:
grid = grid * ranges[j]
else:
grid = grid * ones[j]
grids.append(grid)
return grids
mgrid = _nd_grid()
ogrid = _nd_grid(sparse=True)
class PermuteRowElements(Op): class PermuteRowElements(Op):
"""Permute the elements of each row (inner-most dim) of a tensor. """Permute the elements of each row (inner-most dim) of a tensor.
......
...@@ -48,7 +48,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as, ...@@ -48,7 +48,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as,
nonzero, flatnonzero, nonzero_values, nonzero, flatnonzero, nonzero_values,
stacklists, DimShuffle, hessian, ptp, power, stacklists, DimShuffle, hessian, ptp, power,
swapaxes, choose, Choose, NoneConst, AllocEmpty, swapaxes, choose, Choose, NoneConst, AllocEmpty,
isclose, allclose, isclose, allclose, mgrid, ogrid,
) )
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
...@@ -5491,6 +5491,56 @@ class TestARange(unittest.TestCase): ...@@ -5491,6 +5491,56 @@ class TestARange(unittest.TestCase):
assert numpy.all(f(0) == len(numpy.arange(0, 0))) assert numpy.all(f(0) == len(numpy.arange(0, 0)))
class TestNdGrid(unittest.TestCase):
def setUp(self):
pass
def test_mgrid_numpy_equiv(self):
nmgrid = (numpy.mgrid[0:1:.1, 1:10:1., 10:100:10.],
numpy.mgrid[0:2:1, 1:10:1, 10:100:10])
tmgrid = (mgrid[0:1:.1, 1:10:1., 10:100:10.],
mgrid[0:2:1, 1:10:1, 10:100:10])
for n, t in zip(nmgrid, tmgrid):
for ng, tg in zip(n, t):
assert_array_equal(ng, tg.eval())
def test_ogrid_numpy_equiv(self):
nogrid = (numpy.ogrid[0:1:.1, 1:10:1., 10:100:10.],
numpy.ogrid[0:2:1, 1:10:1, 10:100:10])
togrid = (ogrid[0:1:.1, 1:10:1., 10:100:10.],
ogrid[0:2:1, 1:10:1, 10:100:10])
for n, t in zip(nogrid, togrid):
for ng, tg in zip(n, t):
assert_array_equal(ng, tg.eval())
def test_mgrid_theano_variable_numpy_equiv(self):
nfmgrid = numpy.mgrid[0:1:.1, 1:10:1., 10:100:10.]
nimgrid = numpy.mgrid[0:2:1, 1:10:1, 10:100:10]
i,j,k = dscalars('i','j','k')
l,m,n = iscalars('l','m','n')
tfmgrid = mgrid[i:1:.1, 1:j:1., 10:100:k]
timgrid = mgrid[l:2:1, 1:m:1, 10:100:n]
ff = theano.function([i, j, k], tfmgrid)
fi = theano.function([l, m, n], timgrid)
for n, t in zip((nfmgrid,nimgrid), (ff(0, 10, 10.),fi(0, 10, 10))):
for ng, tg in zip(n, t):
assert_array_equal(ng, tg)
def test_ogrid_theano_variable_numpy_equiv(self):
nfogrid = numpy.ogrid[0:1:.1, 1:10:1., 10:100:10.]
niogrid = numpy.ogrid[0:2:1, 1:10:1, 10:100:10]
i,j,k = dscalars('i','j','k')
l,m,n = iscalars('l','m','n')
tfogrid = ogrid[i:1:.1, 1:j:1., 10:100:k]
tiogrid = ogrid[l:2:1, 1:m:1, 10:100:n]
ff = theano.function([i, j, k], tfogrid)
fi = theano.function([l, m, n], tiogrid)
for n, t in zip((nfogrid,niogrid), (ff(0, 10, 10.),fi(0, 10, 10))):
for ng, tg in zip(n, t):
assert_array_equal(ng, tg)
class TestInversePermutation(unittest.TestCase): class TestInversePermutation(unittest.TestCase):
def setUp(self): def setUp(self):
utt.seed_rng() utt.seed_rng()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论