提交 81245414 作者: Frederic Bastien

gpu opt, don't move float16 ops or cast to/from float32 in the graph for some…

GPU opt: for ops that don't support float16, either don't move them to the GPU or cast to/from float32 in the graph.
上级 8735401b
...@@ -1327,6 +1327,8 @@ theano.tensor.nnet.conv2d() ...@@ -1327,6 +1327,8 @@ theano.tensor.nnet.conv2d()
@op_lifter([SparseBlockGemv]) @op_lifter([SparseBlockGemv])
@register_opt2([SparseBlockGemv], 'fast_compile') @register_opt2([SparseBlockGemv], 'fast_compile')
def local_gpua_sparseblockgemv(op, context_name, inputs, outputs): def local_gpua_sparseblockgemv(op, context_name, inputs, outputs):
if inputs[0].dtype == 'float16':
return
if op.inplace: if op.inplace:
return gpu_sparse_block_gemv_inplace return gpu_sparse_block_gemv_inplace
else: else:
...@@ -1337,6 +1339,8 @@ def local_gpua_sparseblockgemv(op, context_name, inputs, outputs): ...@@ -1337,6 +1339,8 @@ def local_gpua_sparseblockgemv(op, context_name, inputs, outputs):
@op_lifter([SparseBlockOuter]) @op_lifter([SparseBlockOuter])
@register_opt2([SparseBlockOuter], 'fast_compile') @register_opt2([SparseBlockOuter], 'fast_compile')
def local_gpua_sparseblockouter(op, context_name, inputs, outputs): def local_gpua_sparseblockouter(op, context_name, inputs, outputs):
if inputs[0].dtype == 'float16':
return
if op.inplace: if op.inplace:
return gpu_sparse_block_outer_inplace return gpu_sparse_block_outer_inplace
else: else:
...@@ -1990,9 +1994,15 @@ def local_gpu_maxandargmax(op, context_name, inputs, outputs): ...@@ -1990,9 +1994,15 @@ def local_gpu_maxandargmax(op, context_name, inputs, outputs):
def local_gpu_solve(op, context_name, inputs, outputs): def local_gpu_solve(op, context_name, inputs, outputs):
if not cusolver_available: if not cusolver_available:
return return
if inputs[0].dtype not in ['float16', 'float32']:
return
if op.A_structure not in MATRIX_STRUCTURES_SOLVE: if op.A_structure not in MATRIX_STRUCTURES_SOLVE:
return return
return GpuCusolverSolve(A_structure=op.A_structure) op = GpuCusolverSolve(A_structure=op.A_structure)
if inputs[0].dtype == 'float16':
return op(inputs[0].astype('float32'),
inputs[1].astype('float32')).astype('float16')
return op
@register_inplace() @register_inplace()
...@@ -2010,7 +2020,13 @@ def local_inplace_gpu_solve(node): ...@@ -2010,7 +2020,13 @@ def local_inplace_gpu_solve(node):
def local_gpu_cholesky(op, context_name, inputs, outputs): def local_gpu_cholesky(op, context_name, inputs, outputs):
if not cusolver_available: if not cusolver_available:
return return
return GpuCholesky(lower=op.lower, inplace=op.destructive) if inputs[0].dtype not in ['float16', 'float32']:
return
op = GpuCholesky(lower=op.lower, inplace=op.destructive)
if inputs[0].dtype == 'float16':
return op(inputs[0].astype('float32')).astype('float16')
return op
@register_inplace() @register_inplace()
...@@ -2026,7 +2042,12 @@ def local_inplace_cholesky(node): ...@@ -2026,7 +2042,12 @@ def local_inplace_cholesky(node):
def local_gpu_matrix_inverse(op, context_name, inputs, outputs): def local_gpu_matrix_inverse(op, context_name, inputs, outputs):
if not config.magma.enabled: if not config.magma.enabled:
return return
return GpuMagmaMatrixInverse() if inputs[0].dtype not in ['float16', 'float32']:
return
op = GpuMagmaMatrixInverse()
if inputs[0].dtype == 'float16':
return op(inputs[0].astype('float32')).astype('float16')
return op
@register_inplace() @register_inplace()
...@@ -2043,9 +2064,13 @@ def local_inplace_matrix_inverse_inplace(node): ...@@ -2043,9 +2064,13 @@ def local_inplace_matrix_inverse_inplace(node):
def local_gpu_svd(op, context_name, inputs, outputs): def local_gpu_svd(op, context_name, inputs, outputs):
if not config.magma.enabled: if not config.magma.enabled:
return return
return GpuMagmaSVD(full_matrices=op.full_matrices, if inputs[0].dtype not in ['float16', 'float32']:
return
op = GpuMagmaSVD(full_matrices=op.full_matrices,
compute_uv=op.compute_uv) compute_uv=op.compute_uv)
if inputs[0].dtype == 'float16':
return op(inputs[0].astype('float32')).astype('float16')
return op
# Do not register in fast_run or fast_compile. # Do not register in fast_run or fast_compile.
# It will be added to fast_run if the GPU is enabled. # It will be added to fast_run if the GPU is enabled.
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论