提交 cd849526 作者: Arnaud Bergeron

Address comments from review.

上级 e6914181
......@@ -611,6 +611,7 @@ def local_gpua_gemm(node):
def local_gpua_hgemm(node):
from theano.sandbox.cuda import nvcc_compiler
if nvcc_compiler.nvcc_version < '7.5':
log.warning("Not performing dot of float16 on the GPU since cuda 7.5 is not available. Updating could speed up your code.")
return
A = node.inputs[0]
B = node.inputs[1]
......
......@@ -132,6 +132,7 @@ def test_hgemm_swap():
f = theano.function([v, m], tensor.dot(v, m), mode=mode_with_gpu)
assert len([node for node in f.maker.fgraph.apply_nodes
if isinstance(node.op, GpuGemm)]) == 0
f = theano.function([m32, m], tensor.dot(m32, m), mode=mode_with_gpu)
assert len([node for node in f.maker.fgraph.apply_nodes
if isinstance(node.op, GpuGemm)]) == 0
......@@ -140,20 +141,6 @@ def test_hgemm_swap():
assert len([node for node in f.maker.fgraph.apply_nodes
if isinstance(node.op, GpuGemm)]) == 1
def test_hgemm_value():
from theano.sandbox.cuda import nvcc_compiler
if nvcc_compiler.nvcc_version < '7.5':
raise SkipTest("SgemmEx is only avaialble on cuda 7.5+")
m = tensor.matrix(dtype='float16')
m2 = tensor.matrix(dtype='float16')
f = theano.function([m, m2], tensor.dot(m, m2), mode=mode_with_gpu)
assert len([node for node in f.maker.fgraph.apply_nodes
if isinstance(node.op, GpuGemm)]) == 1
v1 = numpy.random.random((3, 4)).astype('float16')
v2 = numpy.random.random((4, 2)).astype('float16')
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论