提交 ca52a0fc authored 作者: Eric Larsen's avatar Eric Larsen 提交者: Frederic

testing infer_shape: op Tile

上级 4a664c8d
......@@ -5527,6 +5527,17 @@ class Tile(Op):
if len(out[0].shape) != self.ndim:
raise ValueError('Tile.perform produced incorrect shape')
def infer_shape(self, node, in_shapes):
# Note: in contrast with numpy, it is assumed that x.shape and reps
# have equal length; see alsor tile function below
x, reps = node.inputs
shp = x.shape
tiled_shp = shp * reps
out_shape = []
for i in range(self.ndim):
out_shape.append(tiled_shp[i])
return [out_shape]
def grad(self, inp, grads):
x, reps = inp
g_out, = grads
......@@ -5539,21 +5550,24 @@ def tile(x, reps, ndim=None):
Tile input array `x` according to `reps`. See the docstring of `numpy.tile`
for details.
Currently, `reps` must be a constant.
Currently, `reps` must be a constant, x.ndim and len(reps) must be equal
and, if specified, 'ndim' must be equal to both.
TODO: expand this.
"""
if len(reps) != x.ndim:
if isinstance(reps, theano.tensor.TensorVariable):
raise ValueError("'reps' argument to 'tile' must be a constant (e.g. "
"tuple, list of integers)")
elif len(reps) != x.ndim:
raise ValueError("len(reps) != x.ndim not currently supported")
elif (ndim is not None) and ndim != x.ndim:
raise ValueError("if specified, ndim must be equal to both x.ndim and "
"len(reps)")
if not hasattr(tile, 'op'):
tile.op = {}
try:
assert python_all([int(i) == i for i in iter(reps)])
except (TypeError, AssertionError):
raise ValueError("reps argument to tile must be a constant (e.g. "
"tuple, list of integers)")
if ndim is None:
ndim = len(reps)
......
......@@ -37,7 +37,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as,
tile, patternbroadcast, Eye, Shape, Default, Dot, PermuteRowElements,
ScalarFromTensor, TensorFromScalar, dtensor4, Rebroadcast, Alloc,
dtensor3, SpecifyShape, Mean, IncSubtensor, AdvancedIncSubtensor1,
itensor3)
itensor3, Tile)
from theano.tests import unittest_tools as utt
from theano.printing import debugprint
......@@ -6625,6 +6625,31 @@ shapes generated:
[Reshape(ndim, '_name')(admat, aivec)],
[admat_val, [4, 3]], Reshape)
# Tile: basic 5292
advec = dvector()
advec_val = rand(5)
aivec_val = [3]
ndim = 1
self._compile_and_check([advec],
[tile(advec, aivec_val, ndim)],
[advec_val], Tile)
admat = dmatrix()
admat_val = rand(2, 4)
aivec_val = [2, 3]
ndim = None
self._compile_and_check([admat],
[tile(admat, aivec_val)],
[admat_val], Tile)
adtens4 = dtensor4()
adtens4_val = rand(2, 4, 3, 5)
aivec_val = [2, 3, 1, 4]
ndim = 4
self._compile_and_check([adtens4],
[tile(adtens4, aivec_val, ndim)],
[adtens4_val], Tile)
if __name__ == '__main__':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论