提交 23817993 authored 作者: Frederic Bastien's avatar Frederic Bastien

review commit. Remove debug print. Test a missed case.

上级 2d4b514d
......@@ -1135,11 +1135,9 @@ def local_subtensor_merge(node):
u_start = u.owner.op.idx_list[0].start
if len(u.owner.inputs) == 1 and isinstance(u_start, int):
print 'int'
start0 = u_start
elif (len(u.owner.inputs) == 2 and
isinstance (u_start, scalar.basic.Scalar)):
print 'scalar'
start0 = T.tensor_from_scalar(u.owner.inputs[1])
else:
return False
......
......@@ -1249,14 +1249,14 @@ class test_local_subtensor_merge(unittest.TestCase):
def test_const(self):
# var[const::][-1] -> var[-1]
x = TT.matrix('x')
for idx in [-2,-1,1,2]:
for idx in range(-5,4):
f = function([x], x[idx::][-1], mode=mode_opt)
#theano.printing.debugprint(f, print_type=True)
topo=f.maker.env.toposort()
print [t for t in topo if isinstance(t.op, TT.Subtensor)]
#print [t for t in topo if isinstance(t.op, TT.Subtensor)]
assert len([t for t in topo if isinstance(t.op, TT.Subtensor)]) == 1
print topo[-1].op
#print topo[-1].op
assert isinstance(topo[-1].op, theano.compile.function_module.DeepCopyOp)
x_val = [[0,1],[2,3]]
......@@ -1276,15 +1276,15 @@ class test_local_subtensor_merge(unittest.TestCase):
#theano.printing.debugprint(f, print_type=True)
topo=f.maker.env.toposort()
print [t for t in topo if isinstance(t.op, TT.Subtensor)]
#print [t for t in topo if isinstance(t.op, TT.Subtensor)]
assert len([t for t in topo if isinstance(t.op, TT.Subtensor)]) == 1
print topo[-1].op
#print topo[-1].op
assert isinstance(topo[-1].op, theano.compile.function_module.DeepCopyOp)
x_val = [[0,1],[2,3]]
for idx in range(-10,2):
f(x_val, idx) # let debugmode test something
for idx in range(2,4):
for idx in range(2,5):
self.assertRaises(IndexError, f, x_val, idx)
def test_dont_opt(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论