提交 b6eaee58 authored 作者: Frederic Bastien's avatar Frederic Bastien

Small change found while doing review.

上级 d12e8b81
...@@ -272,7 +272,6 @@ class Ger(Op): ...@@ -272,7 +272,6 @@ class Ger(Op):
cZ, = node_output_storage cZ, = node_output_storage
def rval(): def rval():
A = cA[0]
if self.destructive: if self.destructive:
A = cA[0] A = cA[0]
else: else:
......
...@@ -1102,49 +1102,49 @@ class TestGer_make_node(TestCase): ...@@ -1102,49 +1102,49 @@ class TestGer_make_node(TestCase):
self.ca = T.cscalar() self.ca = T.cscalar()
self.za = T.zscalar() self.za = T.zscalar()
def test_works_on_all_valid_dtypes(s): def test_works_on_all_valid_dtypes(self):
s.assertEquals(s.fm.type, self.assertEquals(self.fm.type,
ger(s.fm, s.fa, s.fv, s.fv_2).type) ger(self.fm, self.fa, self.fv, self.fv_2).type)
s.assertEquals(s.fm.type, self.assertEquals(self.fm.type,
ger(s.fm, s.fa, s.fv, s.fv_2).type) ger(self.fm, self.fa, self.fv, self.fv_2).type)
s.assertEquals(s.fm.type, self.assertEquals(self.fm.type,
ger(s.fm, s.fa, s.fv, s.fv_2).type) ger(self.fm, self.fa, self.fv, self.fv_2).type)
s.assertEquals(s.fm.type, self.assertEquals(self.fm.type,
ger(s.fm, s.fa, s.fv, s.fv_2).type) ger(self.fm, self.fa, self.fv, self.fv_2).type)
def test_fails_on_invalid_matching_dtypes(s): def test_fails_on_invalid_dtypes(self):
s.assertRaises(TypeError, self.assertRaises(TypeError,
ger, T.imatrix(), T.iscalar(), T.ivector(), ger, T.imatrix(), T.iscalar(), T.ivector(),
T.ivector()) T.ivector())
def test_fails_for_nonscalar_alpha(s): def test_fails_for_nonscalar_alpha(self):
s.assertRaises(TypeError, self.assertRaises(TypeError,
ger, s.fm, s.fm, s.fv, s.fv_2) ger, self.fm, self.fm, self.fv, self.fv_2)
# boundary case - fv1 has the right dtype and could be dimshuffled to a # boundary case - fv1 has the right dtype and could be dimshuffled to a
# scalar, but that's not make_node's job. # scalar, but that's not make_node's job.
s.assertRaises(TypeError, self.assertRaises(TypeError,
ger, s.fm, s.fv1, s.fv, s.fv_2) ger, self.fm, self.fv1, self.fv, self.fv_2)
# actually doing the aforementioned dimshuffle makes it work # actually doing the aforementioned dimshuffle makes it work
s.assertEquals(s.fm.type, self.assertEquals(self.fm.type,
ger(s.fm, s.fv1.dimshuffle(), s.fv, s.fv_2).type) ger(self.fm, self.fv1.dimshuffle(), self.fv, self.fv_2).type)
def test_fails_for_nonmatrix_A(s): def test_fails_for_nonmatrix_A(self):
s.assertRaises(TypeError, self.assertRaises(TypeError,
ger, s.fv, s.fa, s.fv, s.fv_2) ger, self.fv, self.fa, self.fv, self.fv_2)
def test_fails_for_nonvector_x_or_y(s): def test_fails_for_nonvector_x_or_y(self):
s.assertRaises(TypeError, self.assertRaises(TypeError,
ger, s.fm, s.fa, s.fv.dimshuffle('x', 0), s.fv_2) ger, self.fm, self.fa, self.fv.dimshuffle('x', 0), self.fv_2)
s.assertRaises(TypeError, self.assertRaises(TypeError,
ger, s.fm, s.fa, s.fv, s.fv_2.dimshuffle('x', 0)) ger, self.fm, self.fa, self.fv, self.fv_2.dimshuffle('x', 0))
def test_fails_for_mixed_dtypes(s): def test_fails_for_mixed_dtypes(self):
s.assertRaises(TypeError, ger, s.dm, s.fa, s.fv, s.fv_2) self.assertRaises(TypeError, ger, self.dm, self.fa, self.fv, self.fv_2)
s.assertRaises(TypeError, ger, s.fm, s.da, s.fv, s.fv_2) self.assertRaises(TypeError, ger, self.fm, self.da, self.fv, self.fv_2)
s.assertRaises(TypeError, ger, s.fm, s.fa, s.dv, s.fv_2) self.assertRaises(TypeError, ger, self.fm, self.fa, self.dv, self.fv_2)
s.assertRaises(TypeError, ger, s.fm, s.fa, s.fv, s.dv_2) self.assertRaises(TypeError, ger, self.fm, self.fa, self.fv, self.dv_2)
s.assertRaises(TypeError, ger, s.cm, s.fa, s.fv, s.dv_2) self.assertRaises(TypeError, ger, self.cm, self.fa, self.fv, self.dv_2)
s.assertRaises(TypeError, ger, s.cm, s.fa, s.fv, s.zv_2) self.assertRaises(TypeError, ger, self.cm, self.fa, self.fv, self.zv_2)
# TODO: refactor this into some place where all OpTesters could use it. # TODO: refactor this into some place where all OpTesters could use it.
class TestOpContractMixin(object): class TestOpContractMixin(object):
...@@ -1161,7 +1161,7 @@ class TestOpContractMixin(object): ...@@ -1161,7 +1161,7 @@ class TestOpContractMixin(object):
def clone(self, op): def clone(self, op):
raise NotImplementedError('return new instance like `op`') raise NotImplementedError('return new instance like `op`')
def test_eq_ger(self): def test_eq(self):
for i, op_i in enumerate(self.ops): for i, op_i in enumerate(self.ops):
assert op_i == op_i assert op_i == op_i
assert op_i == self.copy(op_i) assert op_i == self.copy(op_i)
...@@ -1253,15 +1253,15 @@ class TestGer_make_thunk(TestCase): ...@@ -1253,15 +1253,15 @@ class TestGer_make_thunk(TestCase):
assert storage_map[sZ][0].dtype == dtype assert storage_map[sZ][0].dtype == dtype
assert storage_map[sZ][0].shape == (M, N) assert storage_map[sZ][0].shape == (M, N)
def test_f32_0_0(s): return s.given_dtype('float32', 0, 0) def test_f32_0_0(self): return self.given_dtype('float32', 0, 0)
def test_f32_1_0(s): return s.given_dtype('float32', 1, 0) def test_f32_1_0(self): return self.given_dtype('float32', 1, 0)
def test_f32_0_1(s): return s.given_dtype('float32', 0, 1) def test_f32_0_1(self): return self.given_dtype('float32', 0, 1)
def test_f32_1_1(s): return s.given_dtype('float32', 1, 1) def test_f32_1_1(self): return self.given_dtype('float32', 1, 1)
def test_f32_4_4(s): return s.given_dtype('float32', 4, 4) def test_f32_4_4(self): return self.given_dtype('float32', 4, 4)
def test_f64_4_5(s): return s.given_dtype('float64', 4, 5) def test_f64_4_5(self): return self.given_dtype('float64', 4, 5)
def test_c64_7_1(s): return s.given_dtype('complex64', 7, 1) def test_c64_7_1(self): return self.given_dtype('complex64', 7, 1)
def test_c128_1_9(s): return s.given_dtype('complex128', 1, 9) def test_c128_1_9(self): return self.given_dtype('complex128', 1, 9)
# TODO: Refactor and add to this base class as we refactor test code. # TODO: Refactor and add to this base class as we refactor test code.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论