提交 8030ffa1 authored 作者: Amjad Almahairi's avatar Amjad Almahairi

putting back the optimization test for useless tile, but now on the Tile op directly

上级 35bf93a3
...@@ -51,6 +51,7 @@ from theano.tensor import ( ...@@ -51,6 +51,7 @@ from theano.tensor import (
join, join,
Subtensor, Subtensor,
TensorType, TensorType,
Tile,
) )
from theano.tensor.elemwise import DimShuffle from theano.tensor.elemwise import DimShuffle
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
...@@ -3401,40 +3402,43 @@ def test_local_mul_specialize(): ...@@ -3401,40 +3402,43 @@ def test_local_mul_specialize():
nodes = [node.op for node in f.maker.fgraph.toposort()] nodes = [node.op for node in f.maker.fgraph.toposort()]
assert nodes == [T.mul] assert nodes == [T.mul]
# Tile op is deprecated
class T_Tile(unittest.TestCase):
# class T_Tile(unittest.TestCase): def test_local_useless_tile(self):
# def test_local_useless_tile(self): # Tile op is deprecated and tile function no more uses it
# v = T.vector() # we'll test the op directly
# m = T.matrix() v = T.vector()
# mode = None m = T.matrix()
# if theano.config.mode == "FAST_COMPILE": mode = None
# mode = "FAST_RUN" if theano.config.mode == "FAST_COMPILE":
# for var, data in [(v, [1, 2, 3]), (m, [[1, 2], [3, 4]])]: mode = "FAST_RUN"
# # Currently, only a repeat patter == ndim is supported. for var, data in [(v, [1, 2, 3]), (m, [[1, 2], [3, 4]])]:
# for ndim in [var.ndim]: # range(1, var.ndim): # Currently, only a repeat patter == ndim is supported.
# f = theano.function([var], T.tile(var, (1,)*ndim), mode=mode) for ndim in [var.ndim]: # range(1, var.ndim):
# topo = f.maker.fgraph.toposort() f = theano.function([var], Tile(ndim)(var, (1,)*ndim), mode=mode)
# assert len(topo) == 1 topo = f.maker.fgraph.toposort()
# assert isinstance(topo[0].op, compile.DeepCopyOp) assert len(topo) == 1
# f(data) assert isinstance(topo[0].op, compile.DeepCopyOp)
f(data)
# # If the repeat parameter is longer then v.ndim, we must
# # replace it with a DimShuffle to add the extra parameter. # If the repeat parameter is longer then v.ndim, we must
# # But it isn't supported for now, so assert that we raise an # replace it with a DimShuffle to add the extra parameter.
# # error. # But it isn't supported for now, so assert that we raise an
# self.assertRaises(ValueError, T.tile, v, (1,)*(v.ndim+1)) # error.
# # If the repeat parameter is shorter then m.ndim, it should
# # pad tot he left the repeat patter with 1. It is not supported for now. self.assertRaises(ValueError, T.tile, v, (1,)*(v.ndim+1))
# #f = theano.function([var], T.tile(v, (1,)*(v.ndim+1))) # If the repeat parameter is shorter then m.ndim, it should
# #topo = f.maker.fgraph.toposort() # pad tot he left the repeat patter with 1. It is not supported for now.
# #assert len(topo) == 1 #f = theano.function([var], T.tile(v, (1,)*(v.ndim+1)))
# #assert isinstance(topo[0].op, DimShuffe) #topo = f.maker.fgraph.toposort()
# self.assertRaises(ValueError, T.tile, m, (1,)*(m.ndim-1)) #assert len(topo) == 1
# #f = theano.function([var], T.tile(m, (1,)*(m.ndim-1))) #assert isinstance(topo[0].op, DimShuffe)
# #topo = f.maker.fgraph.toposort()
# #assert len(topo) == 1 self.assertRaises(ValueError, T.tile, m, (1,)*(m.ndim-1))
# #assert isinstance(topo[0].op, compile.DeepCopyOp) #f = theano.function([var], T.tile(m, (1,)*(m.ndim-1)))
#topo = f.maker.fgraph.toposort()
#assert len(topo) == 1
#assert isinstance(topo[0].op, compile.DeepCopyOp)
def speed_local_pow_specialize_range(): def speed_local_pow_specialize_range():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论