提交 cd849526 作者: Arnaud Bergeron

Address comments from review.

上级 e6914181
...@@ -611,6 +611,7 @@ def local_gpua_gemm(node): ...@@ -611,6 +611,7 @@ def local_gpua_gemm(node):
def local_gpua_hgemm(node): def local_gpua_hgemm(node):
from theano.sandbox.cuda import nvcc_compiler from theano.sandbox.cuda import nvcc_compiler
if nvcc_compiler.nvcc_version < '7.5': if nvcc_compiler.nvcc_version < '7.5':
log.warning("Not performing dot of float16 on the GPU since cuda 7.5 is not available. Updating could speed up your code.")
return return
A = node.inputs[0] A = node.inputs[0]
B = node.inputs[1] B = node.inputs[1]
......
...@@ -132,6 +132,7 @@ def test_hgemm_swap(): ...@@ -132,6 +132,7 @@ def test_hgemm_swap():
f = theano.function([v, m], tensor.dot(v, m), mode=mode_with_gpu) f = theano.function([v, m], tensor.dot(v, m), mode=mode_with_gpu)
assert len([node for node in f.maker.fgraph.apply_nodes assert len([node for node in f.maker.fgraph.apply_nodes
if isinstance(node.op, GpuGemm)]) == 0 if isinstance(node.op, GpuGemm)]) == 0
f = theano.function([m32, m], tensor.dot(m32, m), mode=mode_with_gpu) f = theano.function([m32, m], tensor.dot(m32, m), mode=mode_with_gpu)
assert len([node for node in f.maker.fgraph.apply_nodes assert len([node for node in f.maker.fgraph.apply_nodes
if isinstance(node.op, GpuGemm)]) == 0 if isinstance(node.op, GpuGemm)]) == 0
...@@ -140,20 +141,6 @@ def test_hgemm_swap(): ...@@ -140,20 +141,6 @@ def test_hgemm_swap():
assert len([node for node in f.maker.fgraph.apply_nodes assert len([node for node in f.maker.fgraph.apply_nodes
if isinstance(node.op, GpuGemm)]) == 1 if isinstance(node.op, GpuGemm)]) == 1
def test_hgemm_value():
from theano.sandbox.cuda import nvcc_compiler
if nvcc_compiler.nvcc_version < '7.5':
raise SkipTest("SgemmEx is only avaialble on cuda 7.5+")
m = tensor.matrix(dtype='float16')
m2 = tensor.matrix(dtype='float16')
f = theano.function([m, m2], tensor.dot(m, m2), mode=mode_with_gpu)
assert len([node for node in f.maker.fgraph.apply_nodes
if isinstance(node.op, GpuGemm)]) == 1
v1 = numpy.random.random((3, 4)).astype('float16') v1 = numpy.random.random((3, 4)).astype('float16')
v2 = numpy.random.random((4, 2)).astype('float16') v2 = numpy.random.random((4, 2)).astype('float16')
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论