提交 ca52a0fc authored 作者: Eric Larsen's avatar Eric Larsen 提交者: Frederic

testing infer_shape: op Tile

上级 4a664c8d
...@@ -5527,6 +5527,17 @@ class Tile(Op): ...@@ -5527,6 +5527,17 @@ class Tile(Op):
if len(out[0].shape) != self.ndim: if len(out[0].shape) != self.ndim:
raise ValueError('Tile.perform produced incorrect shape') raise ValueError('Tile.perform produced incorrect shape')
def infer_shape(self, node, in_shapes):
# Note: in contrast with numpy, it is assumed that x.shape and reps
# have equal length; see alsor tile function below
x, reps = node.inputs
shp = x.shape
tiled_shp = shp * reps
out_shape = []
for i in range(self.ndim):
out_shape.append(tiled_shp[i])
return [out_shape]
def grad(self, inp, grads): def grad(self, inp, grads):
x, reps = inp x, reps = inp
g_out, = grads g_out, = grads
...@@ -5539,21 +5550,24 @@ def tile(x, reps, ndim=None): ...@@ -5539,21 +5550,24 @@ def tile(x, reps, ndim=None):
Tile input array `x` according to `reps`. See the docstring of `numpy.tile` Tile input array `x` according to `reps`. See the docstring of `numpy.tile`
for details. for details.
Currently, `reps` must be a constant. Currently, `reps` must be a constant, x.ndim and len(reps) must be equal
and, if specified, 'ndim' must be equal to both.
TODO: expand this. TODO: expand this.
""" """
if len(reps) != x.ndim:
if isinstance(reps, theano.tensor.TensorVariable):
raise ValueError("'reps' argument to 'tile' must be a constant (e.g. "
"tuple, list of integers)")
elif len(reps) != x.ndim:
raise ValueError("len(reps) != x.ndim not currently supported") raise ValueError("len(reps) != x.ndim not currently supported")
elif (ndim is not None) and ndim != x.ndim:
raise ValueError("if specified, ndim must be equal to both x.ndim and "
"len(reps)")
if not hasattr(tile, 'op'): if not hasattr(tile, 'op'):
tile.op = {} tile.op = {}
try:
assert python_all([int(i) == i for i in iter(reps)])
except (TypeError, AssertionError):
raise ValueError("reps argument to tile must be a constant (e.g. "
"tuple, list of integers)")
if ndim is None: if ndim is None:
ndim = len(reps) ndim = len(reps)
......
...@@ -37,7 +37,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as, ...@@ -37,7 +37,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as,
tile, patternbroadcast, Eye, Shape, Default, Dot, PermuteRowElements, tile, patternbroadcast, Eye, Shape, Default, Dot, PermuteRowElements,
ScalarFromTensor, TensorFromScalar, dtensor4, Rebroadcast, Alloc, ScalarFromTensor, TensorFromScalar, dtensor4, Rebroadcast, Alloc,
dtensor3, SpecifyShape, Mean, IncSubtensor, AdvancedIncSubtensor1, dtensor3, SpecifyShape, Mean, IncSubtensor, AdvancedIncSubtensor1,
itensor3) itensor3, Tile)
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from theano.printing import debugprint from theano.printing import debugprint
...@@ -6625,6 +6625,31 @@ shapes generated: ...@@ -6625,6 +6625,31 @@ shapes generated:
[Reshape(ndim, '_name')(admat, aivec)], [Reshape(ndim, '_name')(admat, aivec)],
[admat_val, [4, 3]], Reshape) [admat_val, [4, 3]], Reshape)
# Tile: basic 5292
advec = dvector()
advec_val = rand(5)
aivec_val = [3]
ndim = 1
self._compile_and_check([advec],
[tile(advec, aivec_val, ndim)],
[advec_val], Tile)
admat = dmatrix()
admat_val = rand(2, 4)
aivec_val = [2, 3]
ndim = None
self._compile_and_check([admat],
[tile(admat, aivec_val)],
[admat_val], Tile)
adtens4 = dtensor4()
adtens4_val = rand(2, 4, 3, 5)
aivec_val = [2, 3, 1, 4]
ndim = 4
self._compile_and_check([adtens4],
[tile(adtens4, aivec_val, ndim)],
[adtens4_val], Tile)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论