提交 242fcbfb authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Another test for Ger

上级 fedca15d
......@@ -390,6 +390,26 @@ class TestGpuGer(TestGer):
self.ger = gpu_ger_inplace
self.gemm = tcn.blas.gpu_gemm_inplace
class TestGpuGerNoTransfer(TestGer):
@staticmethod
def shared(val):
try:
return tcn.shared_constructor(val)
except TypeError:
return theano.shared(val)
def setUp(self):
self.mode = mode_with_gpu
dtype = self.dtype = 'float32' # optimization isn't dtype-dependent
self.A = tensor.tensor(dtype=dtype, broadcastable=(False, False))
self.a = tensor.tensor(dtype=dtype, broadcastable=())
self.x = tensor.tensor(dtype=dtype, broadcastable=(False,))
self.y = tensor.tensor(dtype=dtype, broadcastable=(False,))
# data on the gpu make the op always inplace
self.ger = gpu_ger_inplace
self.ger_destructive = gpu_ger_inplace
self.gemm = tcn.blas.gpu_gemm_inplace
class TestGpuGer_OpContract(TestCase, unittest_tools.T_OpContractMixin):
def setUp(self):
......
......@@ -1355,6 +1355,7 @@ class TestGer_OpContract(TestCase, unittest_tools.T_OpContractMixin):
class TestGer(TestCase, unittest_tools.TestOptimizationMixin):
shared = staticmethod(theano.shared)
def setUp(self):
self.mode = theano.compile.get_default_mode().including('fast_run')
......@@ -1480,7 +1481,7 @@ class TestGer(TestCase, unittest_tools.TestOptimizationMixin):
return self.given_dtype('complex128', 1, 9)
def test_inplace(self):
A = theano.shared(numpy.random.rand(4, 5).astype(self.dtype))
A = self.shared(numpy.random.rand(4, 5).astype(self.dtype))
f = self.function([self.x, self.y], [],
updates={A: A + T.constant(0.1, dtype=self.dtype) *
T.outer(self.x, self.y)})
......@@ -1488,6 +1489,8 @@ class TestGer(TestCase, unittest_tools.TestOptimizationMixin):
f(numpy.random.rand(4).astype(self.dtype),
numpy.random.rand(5).astype(self.dtype))
A.set_value(A.get_value(borrow=True)[::-1, ::-1], borrow=True)
A.set_value(
A.get_value(borrow=True, return_internal_type=True)[::-1, ::-1],
borrow=True)
f(numpy.random.rand(4).astype(self.dtype),
numpy.random.rand(5).astype(self.dtype))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论