提交 490dd097 authored 作者: carriepl's avatar carriepl

Merge pull request #3069 from aalmah/ticket_2377

added optimization for dividing by one
......@@ -2429,7 +2429,6 @@ def local_subtensor_merge(node):
lambda x: isinstance(x, T.Variable))
# Do not call make_node for test_value
out = subtens(x, *sl_ins)
return [out]
......@@ -4399,6 +4398,17 @@ def local_mul_to_sqr(node):
return [T.sqr(node.inputs[0])]
@register_canonicalize
@gof.local_optimizer([T.int_div, T.floor_div])
def local_intdiv_by_one(node):
"""x // 1 -> x
"""
if node.op in [T.int_div]:
if isinstance(node.inputs[1], T.TensorConstant) and \
numpy.all(node.inputs[1].value == 1):
return [node.inputs[0]]
@gof.local_optimizer([T.pow])
def local_pow_specialize(node):
# here, we are past the point of canonicalization, so we don't want
......
......@@ -5393,6 +5393,48 @@ def test_assert_op_gradient():
assert func(x_val) == 1
class TestIntDivByOne(unittest.TestCase):
def setUp(self):
self.mode = theano.compile.mode.get_default_mode()
self.mode = self.mode.including('local_intdiv_by_one')
def test1(self):
"""Tests removing the extra floor_div by 1 introduced by
local_subtensor_merge optimization"""
y = T.tensor4('y')
self.mode = self.mode.excluding('fusion')
f = theano.function([y], y[::-1][::-1], mode=self.mode)
graph = f.maker.fgraph.toposort()
divs = [node for node in graph
if isinstance(node.op, T.elemwise.Elemwise) and
isinstance(node.op.scalar_op, theano.scalar.IntDiv)]
assert len(divs) == 0
def test2(self):
"""Simple test case for removing dividing by 1"""
y = T.tensor4('y')
z = y // 1
f = theano.function([y], z, mode = self.mode)
graph = f.maker.fgraph.toposort()
divs = [node for node in graph
if isinstance(node.op, T.elemwise.Elemwise) and
isinstance(node.op.scalar_op, theano.scalar.IntDiv)]
assert len(divs) == 0
def test3(self):
"""Simple test case for removing dividing by a tensor of ones"""
y = T.tensor4('y')
z = y // numpy.ones((2,2,2,2))
f = theano.function([y], z, mode=self.mode)
graph = f.maker.fgraph.toposort()
divs = [node for node in graph
if isinstance(node.op, T.elemwise.Elemwise) and
isinstance(node.op.scalar_op, theano.scalar.IntDiv)]
assert len(divs) == 0
if __name__ == '__main__':
t = TestMakeVector('setUp')
t.setUp()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论