提交 addbc07c authored 作者: Frederic Bastien's avatar Frederic Bastien

small refactoring.

上级 20bc7320
...@@ -1266,6 +1266,7 @@ class test_local_subtensor_merge(unittest.TestCase): ...@@ -1266,6 +1266,7 @@ class test_local_subtensor_merge(unittest.TestCase):
def test_const(self): def test_const(self):
# var[const::][-1] -> var[-1] # var[const::][-1] -> var[-1]
x = TT.matrix('x') x = TT.matrix('x')
x_val = [[0,1],[2,3]]
for idx in range(-5,4): for idx in range(-5,4):
f = function([x], x[idx::][-1], mode=mode_opt) f = function([x], x[idx::][-1], mode=mode_opt)
...@@ -1276,7 +1277,6 @@ class test_local_subtensor_merge(unittest.TestCase): ...@@ -1276,7 +1277,6 @@ class test_local_subtensor_merge(unittest.TestCase):
#print topo[-1].op #print topo[-1].op
assert isinstance(topo[-1].op, theano.compile.function_module.DeepCopyOp) assert isinstance(topo[-1].op, theano.compile.function_module.DeepCopyOp)
x_val = [[0,1],[2,3]]
if idx<2: if idx<2:
# The first subtensor is non-empty, so it makes sense # The first subtensor is non-empty, so it makes sense
f(x_val) # let debugmode test something f(x_val) # let debugmode test something
...@@ -1325,6 +1325,7 @@ class test_local_subtensor_merge(unittest.TestCase): ...@@ -1325,6 +1325,7 @@ class test_local_subtensor_merge(unittest.TestCase):
def test_const2(self): def test_const2(self):
# var[::-1][const] -> var[-1] # var[::-1][const] -> var[-1]
x = TT.matrix('x') x = TT.matrix('x')
x_val = [[0,1],[2,3]]
for idx in range(-5,4): for idx in range(-5,4):
f = function([x], x[::-1][idx], mode=mode_opt) f = function([x], x[::-1][idx], mode=mode_opt)
...@@ -1335,7 +1336,6 @@ class test_local_subtensor_merge(unittest.TestCase): ...@@ -1335,7 +1336,6 @@ class test_local_subtensor_merge(unittest.TestCase):
#print topo[-1].op #print topo[-1].op
assert isinstance(topo[-1].op, theano.compile.function_module.DeepCopyOp) assert isinstance(topo[-1].op, theano.compile.function_module.DeepCopyOp)
x_val = [[0,1],[2,3]]
if idx<2 and idx>=-2: if idx<2 and idx>=-2:
# The first subtensor is non-empty, so it makes sense # The first subtensor is non-empty, so it makes sense
f(x_val) # let debugmode test something f(x_val) # let debugmode test something
...@@ -1384,6 +1384,7 @@ class test_local_subtensor_merge(unittest.TestCase): ...@@ -1384,6 +1384,7 @@ class test_local_subtensor_merge(unittest.TestCase):
def test_const3(self): def test_const3(self):
# var[::-1][:const] -> var[-1] # var[::-1][:const] -> var[-1]
x = TT.matrix('x') x = TT.matrix('x')
x_val = [[0,1],[2,3]]
for idx in range(-5,4): for idx in range(-5,4):
f = function([x], x[::-1][:idx], mode=mode_opt) f = function([x], x[::-1][:idx], mode=mode_opt)
...@@ -1394,7 +1395,6 @@ class test_local_subtensor_merge(unittest.TestCase): ...@@ -1394,7 +1395,6 @@ class test_local_subtensor_merge(unittest.TestCase):
#print topo[-1].op #print topo[-1].op
assert isinstance(topo[-1].op, theano.compile.function_module.DeepCopyOp) assert isinstance(topo[-1].op, theano.compile.function_module.DeepCopyOp)
x_val = [[0,1],[2,3]]
f(x_val) # let debugmode test something f(x_val) # let debugmode test something
def test_scalar3(self): def test_scalar3(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论