提交 bb912c0d authored 作者: Amjad Almahairi's avatar Amjad Almahairi

commeting out tile optimization test

上级 717bd64b
...@@ -3401,39 +3401,40 @@ def test_local_mul_specialize(): ...@@ -3401,39 +3401,40 @@ def test_local_mul_specialize():
nodes = [node.op for node in f.maker.fgraph.toposort()] nodes = [node.op for node in f.maker.fgraph.toposort()]
assert nodes == [T.mul] assert nodes == [T.mul]
# Tile op is deprecated
class T_Tile(unittest.TestCase):
def test_local_useless_tile(self): # class T_Tile(unittest.TestCase):
v = T.vector() # def test_local_useless_tile(self):
m = T.matrix() # v = T.vector()
mode = None # m = T.matrix()
if theano.config.mode == "FAST_COMPILE": # mode = None
mode = "FAST_RUN" # if theano.config.mode == "FAST_COMPILE":
for var, data in [(v, [1, 2, 3]), (m, [[1, 2], [3, 4]])]: # mode = "FAST_RUN"
# Currently, only a repeat patter == ndim is supported. # for var, data in [(v, [1, 2, 3]), (m, [[1, 2], [3, 4]])]:
for ndim in [var.ndim]: # range(1, var.ndim): # # Currently, only a repeat patter == ndim is supported.
f = theano.function([var], T.tile(var, (1,)*ndim), mode=mode) # for ndim in [var.ndim]: # range(1, var.ndim):
topo = f.maker.fgraph.toposort() # f = theano.function([var], T.tile(var, (1,)*ndim), mode=mode)
assert len(topo) == 1 # topo = f.maker.fgraph.toposort()
assert isinstance(topo[0].op, compile.DeepCopyOp) # assert len(topo) == 1
f(data) # assert isinstance(topo[0].op, compile.DeepCopyOp)
# f(data)
# If the repeat parameter is longer then v.ndim, we must
# replace it with a DimShuffle to add the extra parameter. # # If the repeat parameter is longer then v.ndim, we must
# But it isn't supported for now, so assert that we raise an # # replace it with a DimShuffle to add the extra parameter.
# error. # # But it isn't supported for now, so assert that we raise an
self.assertRaises(ValueError, T.tile, v, (1,)*(v.ndim+1)) # # error.
# If the repeat parameter is shorter then m.ndim, it should # self.assertRaises(ValueError, T.tile, v, (1,)*(v.ndim+1))
# pad tot he left the repeat patter with 1. It is not supported for now. # # If the repeat parameter is shorter then m.ndim, it should
#f = theano.function([var], T.tile(v, (1,)*(v.ndim+1))) # # pad tot he left the repeat patter with 1. It is not supported for now.
#topo = f.maker.fgraph.toposort() # #f = theano.function([var], T.tile(v, (1,)*(v.ndim+1)))
#assert len(topo) == 1 # #topo = f.maker.fgraph.toposort()
#assert isinstance(topo[0].op, DimShuffe) # #assert len(topo) == 1
self.assertRaises(ValueError, T.tile, m, (1,)*(m.ndim-1)) # #assert isinstance(topo[0].op, DimShuffe)
#f = theano.function([var], T.tile(m, (1,)*(m.ndim-1))) # self.assertRaises(ValueError, T.tile, m, (1,)*(m.ndim-1))
#topo = f.maker.fgraph.toposort() # #f = theano.function([var], T.tile(m, (1,)*(m.ndim-1)))
#assert len(topo) == 1 # #topo = f.maker.fgraph.toposort()
#assert isinstance(topo[0].op, compile.DeepCopyOp) # #assert len(topo) == 1
# #assert isinstance(topo[0].op, compile.DeepCopyOp)
def speed_local_pow_specialize_range(): def speed_local_pow_specialize_range():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论