提交 bc7018bc authored 作者: Amjad Almahairi's avatar Amjad Almahairi

comments to tests

上级 2e856e90
...@@ -5400,30 +5400,37 @@ class TestIntDivByOne(unittest.TestCase): ...@@ -5400,30 +5400,37 @@ class TestIntDivByOne(unittest.TestCase):
self.mode = self.mode.including('local_intdiv_by_one') self.mode = self.mode.including('local_intdiv_by_one')
def test1(self): def test1(self):
"""Tests removing the extra floor_div by 1 introduced by
local_subtensor_merge optimization"""
y = T.tensor4('y') y = T.tensor4('y')
self.mode = self.mode.excluding('fusion') self.mode = self.mode.excluding('fusion')
f = theano.function([y], y[::-1][::-1], mode=self.mode) f = theano.function([y], y[::-1][::-1], mode=self.mode)
graph = f.maker.fgraph.toposort() graph = f.maker.fgraph.toposort()
divs = [node for node in graph if isinstance(node.op, T.elemwise.Elemwise) and divs = [node for node in graph
if isinstance(node.op, T.elemwise.Elemwise) and
isinstance(node.op.scalar_op, theano.scalar.IntDiv)] isinstance(node.op.scalar_op, theano.scalar.IntDiv)]
assert len(divs) == 0 assert len(divs) == 0
def test2(self): def test2(self):
"""Simple test case for removing dividing by 1"""
y = T.tensor4('y') y = T.tensor4('y')
z = y // 1 z = y // 1
f = theano.function([y], z, mode = self.mode) f = theano.function([y], z, mode = self.mode)
graph = f.maker.fgraph.toposort() graph = f.maker.fgraph.toposort()
divs = [node for node in graph if isinstance(node.op, T.elemwise.Elemwise) and divs = [node for node in graph
if isinstance(node.op, T.elemwise.Elemwise) and
isinstance(node.op.scalar_op, theano.scalar.IntDiv)] isinstance(node.op.scalar_op, theano.scalar.IntDiv)]
assert len(divs) == 0 assert len(divs) == 0
def test3(self): def test3(self):
"""Simple test case for removing dividing by a tensor of ones"""
y = T.tensor4('y') y = T.tensor4('y')
z = y // numpy.ones((2,2,2,2)) z = y // numpy.ones((2,2,2,2))
f = theano.function([y], z, mode=self.mode) f = theano.function([y], z, mode=self.mode)
graph = f.maker.fgraph.toposort() graph = f.maker.fgraph.toposort()
divs = [node for node in graph if isinstance(node.op, T.elemwise.Elemwise) and divs = [node for node in graph
if isinstance(node.op, T.elemwise.Elemwise) and
isinstance(node.op.scalar_op, theano.scalar.IntDiv)] isinstance(node.op.scalar_op, theano.scalar.IntDiv)]
assert len(divs) == 0 assert len(divs) == 0
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论