提交 a9553206 authored 作者: abergeron's avatar abergeron

Merge pull request #1965 from nouiz/tile

Optimization that removes useless tile ops
...@@ -3999,6 +3999,9 @@ class Tile(Op): ...@@ -3999,6 +3999,9 @@ class Tile(Op):
def __hash__(self): def __hash__(self):
return hash(Tile) ^ hash(self.ndim) return hash(Tile) ^ hash(self.ndim)
def __str__(self):
    """Return the class name followed by ``{ndim=<self.ndim>}``."""
    cls_name = self.__class__.__name__
    return cls_name + "{ndim=%d}" % self.ndim
def make_node(self, x, reps): def make_node(self, x, reps):
x = as_tensor_variable(x) x = as_tensor_variable(x)
reps = as_tensor_variable(reps) reps = as_tensor_variable(reps)
......
...@@ -2463,6 +2463,44 @@ def local_div_switch_sink(node): ...@@ -2463,6 +2463,44 @@ def local_div_switch_sink(node):
return False return False
#############
# Tile Opts #
#############
@register_canonicalize
@register_stabilize
@gof.local_optimizer([T.Tile])
def local_useless_tile(node):
    """Replace a no-op ``Tile`` by its input: ``Tile(x, (1,)*N) -> x``.

    A repeat pattern ``(1,)*N`` is a vector whose elements are all 1, so
    the tile repeats nothing.  Only the case ``N == x.ndim`` is rewritten.

    :param node: the apply node to inspect; ``node.inputs[0]`` is the tiled
        variable and ``node.inputs[1]`` is the repeat pattern.
    :return: ``[node.inputs[0]]`` when the tile is useless, otherwise
        ``None`` (no replacement is proposed).
    """
    if not isinstance(node.op, T.Tile):
        return

    x = node.inputs[0]
    reps = node.inputs[1]

    try:
        rep_value = T.get_scalar_constant_value(reps)
    except NotScalarConstantError:
        # The repeat pattern is not a scalar constant: leave the node alone.
        return
    if not rep_value == 1:
        return

    try:
        reps_len = T.get_vector_length(reps)
    except ValueError:
        # presumably the length of the repeat pattern is not known here;
        # either way there is nothing we can safely rewrite.
        return

    if reps_len == x.ndim:
        return [x]

    # Other lengths are deliberately not rewritten: the Op doesn't support
    # them, so the optimization could be neither implemented nor tested.
    # Once supported, the intended rewrites are:
    #   * reps_len < x.ndim: also replace the tile by ``x``;
    #   * reps_len > x.ndim: replace it by
    #     ``x.dimshuffle(['x'] * (reps_len - x.ndim) + list(range(x.ndim)))``
    #     to add the extra leading broadcastable dimensions.
    return
################ ################
# Flatten Opts # # Flatten Opts #
################ ################
......
...@@ -2883,6 +2883,36 @@ def test_local_mul_specialize(): ...@@ -2883,6 +2883,36 @@ def test_local_mul_specialize():
assert nodes == [T.mul] assert nodes == [T.mul]
class T_Tile(unittest.TestCase):
def test_local_useless_tile(self):
# Check that tiling with an all-ones repeat pattern compiles to a graph
# made of a single DeepCopyOp node, and that repeat patterns whose
# length differs from the input's ndim make T.tile raise ValueError.
v = T.vector()
m = T.matrix()
# NOTE(review): `data` is unpacked below but never used in this test.
for var, data in [(v, [1, 2, 3]), (m, [[1, 2], [3, 4]])]:
# Currently, only a repeat pattern with length == ndim is supported.
for ndim in [var.ndim]: # range(1, var.ndim):
f = theano.function([var], T.tile(var, (1,)*ndim))
topo = f.maker.fgraph.toposort()
assert len(topo) == 1
assert isinstance(topo[0].op, compile.DeepCopyOp)
# If the repeat parameter is longer than v.ndim, we must
# replace it with a DimShuffle to add the extra dimensions.
# But it isn't supported for now, so assert that we raise an
# error.
self.assertRaises(ValueError, T.tile, v, (1,)*(v.ndim+1))
# If the repeat parameter is shorter than m.ndim, it should
# pad the repeat pattern to the left with 1. It is not supported for now.
# Disabled check for the longer-pattern case (repeat length v.ndim+1):
#f = theano.function([var], T.tile(v, (1,)*(v.ndim+1)))
#topo = f.maker.fgraph.toposort()
#assert len(topo) == 1
#assert isinstance(topo[0].op, DimShuffle)
self.assertRaises(ValueError, T.tile, m, (1,)*(m.ndim-1))
# Disabled check for the shorter-pattern case (repeat length m.ndim-1):
#f = theano.function([var], T.tile(m, (1,)*(m.ndim-1)))
#topo = f.maker.fgraph.toposort()
#assert len(topo) == 1
#assert isinstance(topo[0].op, compile.DeepCopyOp)
def speed_local_pow_specialize_range(): def speed_local_pow_specialize_range():
val = numpy.random.rand(1e7) val = numpy.random.rand(1e7)
v = T.vector() v = T.vector()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论