提交 b043dece authored 作者: notoraptor's avatar notoraptor

Simplify testing.

上级 274955e4
...@@ -80,57 +80,41 @@ class TestMathErrorFunctions(TestCase): ...@@ -80,57 +80,41 @@ class TestMathErrorFunctions(TestCase):
theano.printing.debugprint(theano_function) theano.printing.debugprint(theano_function)
return False return False
def compute_erfinv_host(self, dtype): def test_elemwise_erfinv(self):
vector = theano.tensor.vector(dtype=dtype) for dtype in self.dtypes:
output = theano.tensor.erfinv(vector)
f = theano.function([vector], output, name='HOST/erfinv/' + dtype, mode=mode_without_gpu)
assert len([n for n in f.maker.fgraph.apply_nodes if isinstance(n.op, GpuElemwise)]) == 0
vector_val = self.default_arrays[dtype]
f(vector_val)
out = f(vector_val)
assert_allclose(self.expected_erfinv_outputs[dtype], out)
def compute_erfinv_gpu(self, dtype):
vector = theano.tensor.vector(dtype=dtype) vector = theano.tensor.vector(dtype=dtype)
output = theano.tensor.erfinv(vector) output = theano.tensor.erfinv(vector)
f = theano.function([vector], output, name='GPU/erfinv/' + dtype, mode=mode_with_gpu) f_host = theano.function([vector], output, name='HOST/erfinv/' + dtype, mode=mode_without_gpu)
f_gpu = theano.function([vector], output, name='GPU/erfinv/' + dtype, mode=mode_with_gpu)
assert len([n for n in f_host.maker.fgraph.apply_nodes if isinstance(n.op, GpuElemwise)]) == 0
if not theano.config.device.startswith('opencl'): if not theano.config.device.startswith('opencl'):
assert self.check_gpu_scalar_op(f, GpuErfinv), 'Function graph does not contains scalar op "GpuErfinv".' assert self.check_gpu_scalar_op(f_gpu, GpuErfinv), \
vector_val = self.default_arrays[dtype] 'Function graph does not contains scalar op "GpuErfinv".'
f(vector_val)
out = f(vector_val)
assert_allclose(self.expected_erfinv_outputs[dtype], out)
def compute_erfcinv_host(self, dtype):
vector = theano.tensor.vector(dtype=dtype)
output = theano.tensor.erfcinv(vector)
f = theano.function([vector], output, name='HOST/erfcinv/' + dtype, mode=mode_without_gpu)
assert len([n for n in f.maker.fgraph.apply_nodes if isinstance(n.op, GpuElemwise)]) == 0
vector_val = self.default_arrays[dtype] vector_val = self.default_arrays[dtype]
f(vector_val) f_host(vector_val)
out = f(vector_val) f_gpu(vector_val)
assert_allclose(self.expected_erfcinv_outputs[dtype], out) out_host = f_host(vector_val)
out_gpu = f_gpu(vector_val)
assert_allclose(out_host, out_gpu)
assert_allclose(self.expected_erfinv_outputs[dtype], out_gpu)
def compute_erfcinv_gpu(self, dtype): def test_elemwise_erfcinv(self):
for dtype in self.dtypes:
vector = theano.tensor.vector(dtype=dtype) vector = theano.tensor.vector(dtype=dtype)
output = theano.tensor.erfcinv(vector) output = theano.tensor.erfcinv(vector)
f = theano.function([vector], output, name='GPU/erfcinv/' + dtype, mode=mode_with_gpu) f_host = theano.function([vector], output, name='HOST/erfcinv/' + dtype, mode=mode_without_gpu)
f_gpu = theano.function([vector], output, name='GPU/erfcinv/' + dtype, mode=mode_with_gpu)
assert len([n for n in f_host.maker.fgraph.apply_nodes if isinstance(n.op, GpuElemwise)]) == 0
if not theano.config.device.startswith('opencl'): if not theano.config.device.startswith('opencl'):
assert self.check_gpu_scalar_op(f, GpuErfcinv), 'Function graph does not contains scalar op "GpuErfcinv".' assert self.check_gpu_scalar_op(f_gpu, GpuErfcinv), \
'Function graph does not contains scalar op "GpuErfcinv".'
vector_val = self.default_arrays[dtype] vector_val = self.default_arrays[dtype]
f(vector_val) f_host(vector_val)
out = f(vector_val) f_gpu(vector_val)
assert_allclose(self.expected_erfcinv_outputs[dtype], out) out_host = f_host(vector_val)
out_gpu = f_gpu(vector_val)
def test_elemwise_erfinv(self): assert_allclose(out_host, out_gpu)
for dtype in self.dtypes: assert_allclose(self.expected_erfcinv_outputs[dtype], out_gpu)
self.compute_erfinv_host(dtype)
self.compute_erfinv_gpu(dtype)
def test_elemwise_erfcinv(self):
for dtype in self.dtypes:
self.compute_erfcinv_host(dtype)
self.compute_erfcinv_gpu(dtype)
class test_float16(): class test_float16():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论