提交 732f8bcb authored 作者: David Warde-Farley's avatar David Warde-Farley

Add tests for tile and (future) test for TileGrad.

上级 f8012657
......@@ -33,7 +33,8 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as,
var, value, Join, shape, MaxAndArgmax, lscalar, zvector, exp,
get_constant_value, ivector, reshape, scalar_from_tensor, scal,
iscalars, arange, dscalars, fvector, imatrix, numeric_grad,
opt, ComplexError, TensorDot, lvector, true_div, max, min, Split, roll)
opt, ComplexError, TensorDot, lvector, true_div, max, min, Split, roll,
tile)
from theano.tests import unittest_tools as utt
......@@ -3984,11 +3985,42 @@ def test_flatten_outdim_invalid():
except ValueError:
pass
# TODO: write test case for Tile Op
# See Ticket #619
#def test_tile():
# print >> sys.stderr, "WARNING: No testcase for Tile"
# pass
def test_tile():
# Test the one-dimensional case.
rng = numpy.random.RandomState(utt.fetch_seed())
x = vector()
f = function([x], tile(x, (2,)))
x_ = rng.randn(5)
assert numpy.all(f(x_) == numpy.tile(x_, (2,)))
# Test the two-dimensional case.
x = matrix()
f = function([x], tile(x, (2, 3)))
x_ = rng.randn(2, 4)
assert numpy.all(f(x_) == numpy.tile(x_, (2, 3)))
# Test the three-dimensional case.
x = tensor3()
f = function([x], tile(x, (2, 3, 4)))
x_ = rng.randn(2, 4, 3)
assert numpy.all(f(x_) == numpy.tile(x_, (2, 3, 4)))
# XXX: It turns out that almost no cases of the tile gradient actually work.
# This is a test that should pass if the proper implementation is filled in.
def test_tile_grad_3d():
raise SkipTest() # Remove me when this is implemented.
rng = numpy.random.RandomState(utt.fetch_seed())
w = rng.randn(3, 4, 2)
w_tiled = numpy.tile(w, (2, 3, 4))
x = tensor.tensor3()
c = (as_tensor_variable(w_tiled) * tile(x, (2, 3, 4))).sum()
f = function([x], grad(c, x))
x_ = rng.randn(3, 4, 2)
# The gradient should be w, multiplied by its tiling dimensions (since
# the gradients are additive through the tiling operation)
assert numpy.all(f(x_) == 2 * 3 * 4 * w)
class TestARange(unittest.TestCase):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论