提交 b4319a6b authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Update test for useless tile to use the tile function

Now it should work.
上级 3262c424
...@@ -62,7 +62,6 @@ from theano.tensor import ( ...@@ -62,7 +62,6 @@ from theano.tensor import (
join, join,
Subtensor, Subtensor,
TensorType, TensorType,
Tile,
tile tile
) )
from theano.tensor.elemwise import DimShuffle from theano.tensor.elemwise import DimShuffle
...@@ -4219,23 +4218,31 @@ def test_local_mul_specialize(): ...@@ -4219,23 +4218,31 @@ def test_local_mul_specialize():
class T_Tile(unittest.TestCase): class T_Tile(unittest.TestCase):
def test_local_useless_tile(self): def test_local_useless_tile(self):
# Tile op is deprecated so the tile function doesn't use it
# anymore, we'll test here the op directly
v = T.vector() v = T.vector()
m = T.matrix() m = T.matrix()
mode = None mode = None
if theano.config.mode == "FAST_COMPILE": if theano.config.mode == "FAST_COMPILE":
mode = "FAST_RUN" mode = "FAST_RUN"
for var, data in [(v, [1, 2, 3]), (m, [[1, 2], [3, 4]])]: for var, data in [(v, [1, 2, 3]), (m, [[1, 2], [3, 4]])]:
# Currently, only a repeat patter == ndim is supported. # When len(repeat pattern) <= var.ndim, everything is removed
for ndim in [var.ndim]: # range(1, var.ndim): # for ndim in range(1, var.ndim):
f = theano.function([var], Tile(ndim)(var, (1,)*ndim), mode=mode) for ndim in range(var.ndim + 1):
f = theano.function([var], tile(var, (1,) * ndim), mode=mode)
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
assert len(topo) == 1 assert len(topo) == 1
assert isinstance(topo[0].op, compile.DeepCopyOp) assert isinstance(topo[0].op, compile.DeepCopyOp)
f(data) f(data)
# In this case the opt only removes nodes, # In this case the opt only removes nodes,
# no need to check_stack_trace # no need to check_stack_trace
# When len(repeat pattern) > var.ndim, only a dimshuffle should be
# left, but there can be a DeepCopy as well
for ndim in range(var.ndim + 1, var.ndim + 3):
f = theano.function([var], tile(var, (1,) * ndim), mode=mode)
topo = f.maker.fgraph.toposort()
assert len(topo) <= 2
assert isinstance(topo[0].op, DimShuffle)
assert check_stack_trace(f, ops_to_check=[DimShuffle])
f(data)
def speed_local_pow_specialize_range(): def speed_local_pow_specialize_range():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论