提交 2e856e90 authored 作者: Amjad Almahairi's avatar Amjad Almahairi

new fixes and added a test

上级 2ec5abc5
......@@ -4397,6 +4397,7 @@ def local_mul_to_sqr(node):
if node.inputs[0] is node.inputs[1]:
return [T.sqr(node.inputs[0])]
@register_canonicalize
@gof.local_optimizer([T.int_div, T.floor_div])
def local_intdiv_by_one(node):
......@@ -4404,7 +4405,7 @@ def local_intdiv_by_one(node):
"""
if node.op in [T.int_div]:
if isinstance(node.inputs[1], T.TensorConstant) and \
node.inputs[1].value == 1:
numpy.all(node.inputs[1].value == 1):
return [node.inputs[0]]
......
......@@ -5393,28 +5393,37 @@ def test_assert_op_gradient():
assert func(x_val) == 1
class TestDivByOne(unittest.TestCase):
def test1(self):
y = T.tensor4('y')
class TestIntDivByOne(unittest.TestCase):
mode = theano.compile.mode.get_default_mode()
mode_wo_fusion = mode.excluding('fusion')
def setUp(self):
self.mode = theano.compile.mode.get_default_mode()
self.mode = self.mode.including('local_intdiv_by_one')
f = theano.function([y], y[::-1][::-1], mode=mode_wo_fusion)
def test1(self):
y = T.tensor4('y')
self.mode = self.mode.excluding('fusion')
f = theano.function([y], y[::-1][::-1], mode=self.mode)
graph = f.maker.fgraph.toposort()
divs = [node for node in graph
if isinstance(node.op, T.elemwise.Elemwise) and
divs = [node for node in graph if isinstance(node.op, T.elemwise.Elemwise) and
isinstance(node.op.scalar_op, theano.scalar.IntDiv)]
assert len(divs) == 0
def test2(self):
y = T.tensor4('y')
z = y // 1
f = theano.function([y], z)
f = theano.function([y], z, mode = self.mode)
graph = f.maker.fgraph.toposort()
divs = [node for node in graph if isinstance(node.op, T.elemwise.Elemwise) and
isinstance(node.op.scalar_op, theano.scalar.IntDiv)]
assert len(divs) == 0
def test3(self):
y = T.tensor4('y')
z = y // numpy.ones((2,2,2,2))
f = theano.function([y], z, mode=self.mode)
graph = f.maker.fgraph.toposort()
divs = [node for node in graph
if isinstance(node.op, T.elemwise.Elemwise) and
divs = [node for node in graph if isinstance(node.op, T.elemwise.Elemwise) and
isinstance(node.op.scalar_op, theano.scalar.IntDiv)]
assert len(divs) == 0
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论