提交 1753068c authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Added fix suggested by Pascal for the tests when dealing with the exception…

Added fix suggested by Pascal for the tests when dealing with the exception thrown by Subtensor (e_invalid)
上级 d98a89d3
...@@ -2319,7 +2319,10 @@ class Subtensor(Op): ...@@ -2319,7 +2319,10 @@ class Subtensor(Op):
idx_list = list(self.idx_list) idx_list = list(self.idx_list)
if len(idx_list) > x.type.ndim: if len(idx_list) > x.type.ndim:
raise ValueError(Subtensor.e_invalid%(len(idx_list), x.type.ndim)) exception = ValueError(Subtensor.e_invalid%(len(idx_list),
x.type.ndim))
exception.subtensor_invalid = True
raise exception
#infer the broadcasting pattern #infer the broadcasting pattern
padded = idx_list + [slice(0,sys.maxint,1)] * (x.type.ndim - len(idx_list)) padded = idx_list + [slice(0,sys.maxint,1)] * (x.type.ndim - len(idx_list))
...@@ -2598,7 +2601,10 @@ class IncSubtensor(Op): ...@@ -2598,7 +2601,10 @@ class IncSubtensor(Op):
idx_list = list(self.idx_list) idx_list = list(self.idx_list)
if len(idx_list) > x.type.ndim: if len(idx_list) > x.type.ndim:
raise ValueError(Subtensor.e_invalid%(len(idx_list), x.type.ndim)) exception = ValueError(Subtensor.e_invalid%(len(idx_list),
x.type.ndim))
exception.subtensor_invalid = True
raise exception
#infer the broadcasting pattern #infer the broadcasting pattern
padded = idx_list + [slice(0,sys.maxint,1)] * (x.type.ndim - len(idx_list)) padded = idx_list + [slice(0,sys.maxint,1)] * (x.type.ndim - len(idx_list))
......
...@@ -647,7 +647,7 @@ TanhInplaceTester = makeBroadcastTester(op = inplace.tanh_inplace, ...@@ -647,7 +647,7 @@ TanhInplaceTester = makeBroadcastTester(op = inplace.tanh_inplace,
grad = _grad_broadcast_unary_normal, grad = _grad_broadcast_unary_normal,
inplace = True) inplace = True)
#inplace ops when the input is integer and the output is float* #inplace ops when the input is integer and the output is float*
# don't have a well defined behavior. We don't test that case. # don't have a well defined behavior. We don't test that case.
_good_broadcast_unary_normal_no_int = _good_broadcast_unary_normal.copy() _good_broadcast_unary_normal_no_int = _good_broadcast_unary_normal.copy()
del _good_broadcast_unary_normal_no_int['integers'] del _good_broadcast_unary_normal_no_int['integers']
...@@ -903,7 +903,7 @@ class T_max_and_argmax(unittest.TestCase): ...@@ -903,7 +903,7 @@ class T_max_and_argmax(unittest.TestCase):
def test_grad(self): def test_grad(self):
data = numpy.random.rand(2,3) data = numpy.random.rand(2,3)
n = as_tensor_variable(data) n = as_tensor_variable(data)
def check_grad_max(data, max_grad_data, axis=None): def check_grad_max(data, max_grad_data, axis=None):
#This work only for axis in [0,None] #This work only for axis in [0,None]
assert axis in [0,None] assert axis in [0,None]
...@@ -915,7 +915,7 @@ class T_max_and_argmax(unittest.TestCase): ...@@ -915,7 +915,7 @@ class T_max_and_argmax(unittest.TestCase):
else: else:
for id,v in enumerate(argmax): for id,v in enumerate(argmax):
z[v*numpy.prod(data.shape[data.ndim-1:axis:-1])+id]+=1 z[v*numpy.prod(data.shape[data.ndim-1:axis:-1])+id]+=1
z = z.reshape(data.shape) z = z.reshape(data.shape)
assert numpy.all(max_grad_data == z) assert numpy.all(max_grad_data == z)
...@@ -1053,7 +1053,7 @@ class T_argmin_argmax(unittest.TestCase): ...@@ -1053,7 +1053,7 @@ class T_argmin_argmax(unittest.TestCase):
def test_grad_argmin(self): def test_grad_argmin(self):
data = numpy.random.rand(2,3) data = numpy.random.rand(2,3)
n = as_tensor_variable(data) n = as_tensor_variable(data)
#test grad of argmin #test grad of argmin
utt.verify_grad(lambda v: argmin(v), [data]) utt.verify_grad(lambda v: argmin(v), [data])
...@@ -1072,7 +1072,7 @@ class T_argmin_argmax(unittest.TestCase): ...@@ -1072,7 +1072,7 @@ class T_argmin_argmax(unittest.TestCase):
def test_grad_argmax(self): def test_grad_argmax(self):
data = numpy.random.rand(2,3) data = numpy.random.rand(2,3)
n = as_tensor_variable(data) n = as_tensor_variable(data)
#test grad of argmax #test grad of argmax
utt.verify_grad(lambda v: argmax(v), [data]) utt.verify_grad(lambda v: argmax(v), [data])
...@@ -1172,7 +1172,7 @@ class T_min_max(unittest.TestCase): ...@@ -1172,7 +1172,7 @@ class T_min_max(unittest.TestCase):
v = eval_outputs(fct(n,-2)) v = eval_outputs(fct(n,-2))
self.failUnless(v.shape == (3,)) self.failUnless(v.shape == (3,))
self.failUnless(numpy.all(v == nfct(n.value,-2))) self.failUnless(numpy.all(v == nfct(n.value,-2)))
v = eval_outputs(fct(n,-1).shape) v = eval_outputs(fct(n,-1).shape)
assert v==(2) assert v==(2)
v = eval_outputs(fct(n,-2).shape) v = eval_outputs(fct(n,-2).shape)
...@@ -1220,7 +1220,7 @@ class T_min_max(unittest.TestCase): ...@@ -1220,7 +1220,7 @@ class T_min_max(unittest.TestCase):
def test_grad_max(self): def test_grad_max(self):
data = numpy.random.rand(2,3) data = numpy.random.rand(2,3)
n = as_tensor_variable(data) n = as_tensor_variable(data)
def check_grad_max(data, max_grad_data, axis=None): def check_grad_max(data, max_grad_data, axis=None):
#This work only for axis in [0,None] #This work only for axis in [0,None]
assert axis in [0,None] assert axis in [0,None]
...@@ -1232,7 +1232,7 @@ class T_min_max(unittest.TestCase): ...@@ -1232,7 +1232,7 @@ class T_min_max(unittest.TestCase):
else: else:
for id,v in enumerate(argmax): for id,v in enumerate(argmax):
z[v*numpy.prod(data.shape[data.ndim-1:axis:-1])+id]+=1 z[v*numpy.prod(data.shape[data.ndim-1:axis:-1])+id]+=1
z = z.reshape(data.shape) z = z.reshape(data.shape)
assert numpy.all(max_grad_data == z) assert numpy.all(max_grad_data == z)
...@@ -1252,7 +1252,7 @@ class T_min_max(unittest.TestCase): ...@@ -1252,7 +1252,7 @@ class T_min_max(unittest.TestCase):
def test_grad_min(self): def test_grad_min(self):
data = numpy.random.rand(2,3) data = numpy.random.rand(2,3)
n = as_tensor_variable(data) n = as_tensor_variable(data)
def check_grad_min(data, min_grad_data, axis=None): def check_grad_min(data, min_grad_data, axis=None):
#This work only for axis in [0,None] #This work only for axis in [0,None]
assert axis in [0,None] assert axis in [0,None]
...@@ -1264,7 +1264,7 @@ class T_min_max(unittest.TestCase): ...@@ -1264,7 +1264,7 @@ class T_min_max(unittest.TestCase):
else: else:
for id,v in enumerate(argmin): for id,v in enumerate(argmin):
z[v*numpy.prod(data.shape[data.ndim-1:axis:-1])+id]+=1 z[v*numpy.prod(data.shape[data.ndim-1:axis:-1])+id]+=1
z = z.reshape(data.shape) z = z.reshape(data.shape)
assert numpy.all(min_grad_data == z) assert numpy.all(min_grad_data == z)
...@@ -1304,7 +1304,7 @@ class T_subtensor(unittest.TestCase): ...@@ -1304,7 +1304,7 @@ class T_subtensor(unittest.TestCase):
try: try:
t = n[0] t = n[0]
except ValueError, e: except ValueError, e:
self.failUnless(e[0] is Subtensor.e_invalid) self.failUnless(hasattr(e,'subtensor_invalid'))
return return
self.fail() self.fail()
...@@ -1356,7 +1356,7 @@ class T_subtensor(unittest.TestCase): ...@@ -1356,7 +1356,7 @@ class T_subtensor(unittest.TestCase):
try: try:
t = n[0,0] t = n[0,0]
except ValueError, e: except ValueError, e:
self.failUnless(e[0] is Subtensor.e_invalid) self.failUnless(hasattr(e,'subtensor_invalid'))
return return
self.fail() self.fail()
def test1_ok_elem(self): def test1_ok_elem(self):
...@@ -2561,7 +2561,7 @@ def test_flatten_outdim_invalid(): ...@@ -2561,7 +2561,7 @@ def test_flatten_outdim_invalid():
# TODO: write test case for Tile Op # TODO: write test case for Tile Op
def test_tile(): def test_tile():
print >> sys.stderr, "WARNING: No testcase for Tile" print >> sys.stderr, "WARNING: No testcase for Tile"
pass pass
class TestARange(unittest.TestCase): class TestARange(unittest.TestCase):
...@@ -2724,7 +2724,7 @@ class TestARange(unittest.TestCase): ...@@ -2724,7 +2724,7 @@ class TestARange(unittest.TestCase):
f = function([stop], out.shape, mode=mode) f = function([stop], out.shape, mode=mode)
assert len(f.maker.env.toposort())==2 assert len(f.maker.env.toposort())==2
#[Elemwise{Cast{int64}}(stop), MakeVector(Elemwise{Cast{int64}}.0)] #[Elemwise{Cast{int64}}(stop), MakeVector(Elemwise{Cast{int64}}.0)]
assert out.dtype == start.type.dtype assert out.dtype == start.type.dtype
assert numpy.all(f(5) == len(numpy.arange(0,5))) assert numpy.all(f(5) == len(numpy.arange(0,5)))
assert numpy.all(f(11) == len(numpy.arange(0,11))) assert numpy.all(f(11) == len(numpy.arange(0,11)))
...@@ -2961,7 +2961,7 @@ class test_tensordot(unittest.TestCase): ...@@ -2961,7 +2961,7 @@ class test_tensordot(unittest.TestCase):
self.failUnless(numpy.allclose(numpy.tensordot(aval,bval,axes), self.failUnless(numpy.allclose(numpy.tensordot(aval,bval,axes),
f5(aval,bval))) f5(aval,bval)))
utt.verify_grad(TensorDot(axes), [aval,bval]) utt.verify_grad(TensorDot(axes), [aval,bval])
axes = (axes[1],axes[0]) axes = (axes[1],axes[0])
c = tensordot(axes)(btens, atens) c = tensordot(axes)(btens, atens)
f6 = inplace_func([btens,atens],c) f6 = inplace_func([btens,atens],c)
...@@ -3051,7 +3051,7 @@ class test_tensordot(unittest.TestCase): ...@@ -3051,7 +3051,7 @@ class test_tensordot(unittest.TestCase):
def test_tensordot_grad(self): def test_tensordot_grad(self):
#We test it manually as we recreate the op in the make_node #We test it manually as we recreate the op in the make_node
amat = matrix() amat = matrix()
bmat = matrix() bmat = matrix()
gzmat = matrix() gzmat = matrix()
...@@ -3245,17 +3245,17 @@ class test_broadcast(unittest.TestCase): ...@@ -3245,17 +3245,17 @@ class test_broadcast(unittest.TestCase):
test that the unbroadcast fct don't insert not needed broadcast test that the unbroadcast fct don't insert not needed broadcast
and fuse consecutive Rebroadcast op and fuse consecutive Rebroadcast op
""" """
x=matrix() x=matrix()
assert unbroadcast(x,0) is x assert unbroadcast(x,0) is x
assert unbroadcast(x,1) is x assert unbroadcast(x,1) is x
assert unbroadcast(x,1,0) is x assert unbroadcast(x,1,0) is x
assert unbroadcast(x,0,1) is x assert unbroadcast(x,0,1) is x
assert addbroadcast(x,0) is not x assert addbroadcast(x,0) is not x
assert addbroadcast(x,1) is not x assert addbroadcast(x,1) is not x
assert addbroadcast(x,1,0).owner.inputs[0] is x assert addbroadcast(x,1,0).owner.inputs[0] is x
assert unbroadcast(addbroadcast(x,0),0) is x assert unbroadcast(addbroadcast(x,0),0) is x
assert addbroadcast(unbroadcast(x,0),0) is not x assert addbroadcast(unbroadcast(x,0),0) is not x
x=row() x=row()
...@@ -3263,15 +3263,15 @@ class test_broadcast(unittest.TestCase): ...@@ -3263,15 +3263,15 @@ class test_broadcast(unittest.TestCase):
assert unbroadcast(x,1) is x assert unbroadcast(x,1) is x
assert unbroadcast(x,1,0) is not x assert unbroadcast(x,1,0) is not x
assert unbroadcast(x,0,1) is not x assert unbroadcast(x,0,1) is not x
assert addbroadcast(x,0) is x assert addbroadcast(x,0) is x
assert addbroadcast(x,1).owner.inputs[0] is x assert addbroadcast(x,1).owner.inputs[0] is x
assert addbroadcast(x,1,0).owner.inputs[0] is x assert addbroadcast(x,1,0).owner.inputs[0] is x
assert addbroadcast(x,0,1).owner.inputs[0] is x assert addbroadcast(x,0,1).owner.inputs[0] is x
assert unbroadcast(addbroadcast(x,1),1) is x assert unbroadcast(addbroadcast(x,1),1) is x
assert addbroadcast(unbroadcast(x,1),1) is not x assert addbroadcast(unbroadcast(x,1),1) is not x
#the first broadcast is remove the broadcast, so the second #the first broadcast is remove the broadcast, so the second
#should not make one #should not make one
assert unbroadcast(unbroadcast(x,0),0).owner.inputs[0] is x assert unbroadcast(unbroadcast(x,0),0).owner.inputs[0] is x
...@@ -3281,10 +3281,10 @@ class test_broadcast(unittest.TestCase): ...@@ -3281,10 +3281,10 @@ class test_broadcast(unittest.TestCase):
assert unbroadcast(unbroadcast(x,1),0).owner.inputs[0] is x assert unbroadcast(unbroadcast(x,1),0).owner.inputs[0] is x
assert addbroadcast(unbroadcast(x,1),0).owner.inputs[0] is x assert addbroadcast(unbroadcast(x,1),0).owner.inputs[0] is x
assert addbroadcast(unbroadcast(x,0),0) is x assert addbroadcast(unbroadcast(x,0),0) is x
def test_mod(): def test_mod():
""" """
We add this test as not all language and C implementation give the same We add this test as not all language and C implementation give the same
signe to the result. This check that the c_code of `Mod` is implemented signe to the result. This check that the c_code of `Mod` is implemented
as Python. That is what we want. as Python. That is what we want.
""" """
...@@ -3298,7 +3298,7 @@ def test_mod(): ...@@ -3298,7 +3298,7 @@ def test_mod():
def test_mod_compile(): def test_mod_compile():
""" """
This test generate an Elemwise of Composite as: This test generate an Elemwise of Composite as:
Elemwise{Composite{Composite{Composite{Composite{mod,EQ},Switch},mul},add}} Elemwise{Composite{Composite{Composite{Composite{mod,EQ},Switch},mul},add}}
The c_code generated is not compiling as of 30 June 2010. I fix the compilation in the same commit. The c_code generated is not compiling as of 30 June 2010. I fix the compilation in the same commit.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论