提交 fc65d14a authored 作者: Frederic's avatar Frederic

small ger op contract refactoring to have them run on the gpu ger too.

上级 4855e5b1
...@@ -280,3 +280,11 @@ class TestGpuGer(TestGer_local_gemm_to_ger): ...@@ -280,3 +280,11 @@ class TestGpuGer(TestGer_local_gemm_to_ger):
# data on the gpu make the op always inplace # data on the gpu make the op always inplace
self.ger = gpu_ger_inplace self.ger = gpu_ger_inplace
self.gemm = tcn.blas.gpu_gemm_inplace self.gemm = tcn.blas.gpu_gemm_inplace
class TestGpuGer_OpContract(TestCase, unittest_tools.T_OpContractMixin):
def setUp(self):
self.ops = [gpu_ger_no_inplace, gpu_ger_inplace]
def clone(self, op):
return tcn.blas.GpuGer(op.inplace)
...@@ -1271,59 +1271,15 @@ class TestGer_make_node(TestCase): ...@@ -1271,59 +1271,15 @@ class TestGer_make_node(TestCase):
self.assertRaises(TypeError, ger, self.cm, self.fa, self.fv, self.dv_2) self.assertRaises(TypeError, ger, self.cm, self.fa, self.fv, self.dv_2)
self.assertRaises(TypeError, ger, self.cm, self.fa, self.fv, self.zv_2) self.assertRaises(TypeError, ger, self.cm, self.fa, self.fv, self.zv_2)
# TODO: refactor this into some place where all OpTesters could use it.
# This object name should not start with Test.
# Otherwise nosetests will execute it!
class T_OpContractMixin(object):
# self.ops should be a list of instantiations of an Op class to test.
# self.other_op should be an op which is different from every op
other_op = T.add
def copy(self, x): class TestGer_OpContract(TestCase, unittest_tools.T_OpContractMixin):
return copy(x)
def deepcopy(self, x):
return deepcopy(x)
def clone(self, op):
raise NotImplementedError('return new instance like `op`')
def test_eq(self):
for i, op_i in enumerate(self.ops):
assert op_i == op_i
assert op_i == self.copy(op_i)
assert op_i == self.deepcopy(op_i)
assert op_i == self.clone(op_i)
assert op_i != self.other_op
for j, op_j in enumerate(self.ops):
if i == j: continue
assert op_i != op_j
def test_hash(self):
for i, op_i in enumerate(self.ops):
h_i = hash(op_i)
assert h_i == hash(op_i)
assert h_i == hash(self.copy(op_i))
assert h_i == hash(self.deepcopy(op_i))
assert h_i == hash(self.clone(op_i))
assert h_i != hash(self.other_op)
for j, op_j in enumerate(self.ops):
if i == j: continue
assert op_i != hash(op_j)
def test_name(self):
for op in self.ops:
s = str(op) # show that str works
assert s # names should not be empty
class TestGer_OpContract(TestCase, T_OpContractMixin):
#TODO: These tests could be factored into a generic Op-testing base-class
def setUp(self): def setUp(self):
self.ops = [ger, ger_destructive] self.ops = [ger, ger_destructive]
def clone(self, op): def clone(self, op):
return Ger(op.destructive) return Ger(op.destructive)
class TestGer_make_thunk(TestCase): class TestGer_make_thunk(TestCase):
def setUp(self): def setUp(self):
self.rng = numpy.random.RandomState(unittest_tools.fetch_seed()) self.rng = numpy.random.RandomState(unittest_tools.fetch_seed())
......
from copy import copy, deepcopy
import sys import sys
import numpy import numpy
...@@ -104,4 +105,46 @@ class TestOptimizationMixin(object): ...@@ -104,4 +105,46 @@ class TestOptimizationMixin(object):
raise SkipTest(msg) raise SkipTest(msg)
# This object name should not start with Test.
# Otherwise nosetests will execute it!
class T_OpContractMixin(object):
# self.ops should be a list of instantiations of an Op class to test.
# self.other_op should be an op which is different from every op
other_op = T.add
def copy(self, x):
return copy(x)
def deepcopy(self, x):
return deepcopy(x)
def clone(self, op):
raise NotImplementedError('return new instance like `op`')
def test_eq(self):
for i, op_i in enumerate(self.ops):
assert op_i == op_i
assert op_i == self.copy(op_i)
assert op_i == self.deepcopy(op_i)
assert op_i == self.clone(op_i)
assert op_i != self.other_op
for j, op_j in enumerate(self.ops):
if i == j: continue
assert op_i != op_j
def test_hash(self):
for i, op_i in enumerate(self.ops):
h_i = hash(op_i)
assert h_i == hash(op_i)
assert h_i == hash(self.copy(op_i))
assert h_i == hash(self.deepcopy(op_i))
assert h_i == hash(self.clone(op_i))
assert h_i != hash(self.other_op)
for j, op_j in enumerate(self.ops):
if i == j: continue
assert op_i != hash(op_j)
def test_name(self):
for op in self.ops:
s = str(op) # show that str works
assert s # names should not be empty
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论