提交 658563c8 authored 作者: Frederic Bastien's avatar Frederic Bastien 提交者: sentient07

Don't move to the GPU scalar int and make GraphToGPU support target.

上级 31dcc3f7
...@@ -227,6 +227,10 @@ class InputToGpuOptimizer(Optimizer): ...@@ -227,6 +227,10 @@ class InputToGpuOptimizer(Optimizer):
target = getattr(input.tag, 'target', None) target = getattr(input.tag, 'target', None)
if target == 'cpu': if target == 'cpu':
continue continue
# Do not move *int* scalar to the GPU.
if (isinstance(input.type, tensor.TensorType) and
input.ndim == 0 and 'int' in input.dtype):
continue
try: try:
new_input = host_from_gpu(GpuFromHost(target)(input)) new_input = host_from_gpu(GpuFromHost(target)(input))
...@@ -273,8 +277,12 @@ class GraphToGPU(NavigatorOptimizer): ...@@ -273,8 +277,12 @@ class GraphToGPU(NavigatorOptimizer):
# Building a new graph # Building a new graph
# Iterating through inputs of graph # Iterating through inputs of graph
for i in fgraph.inputs: for i in fgraph.inputs:
if isinstance(i.type, tensor.TensorType): # Do not move *int* scalar to the GPU.
mapping[i] = as_gpuarray_variable(i, None) # TODO context target = getattr(i.tag, 'target', None)
if (target != 'cpu' and
isinstance(i.type, tensor.TensorType) and
(i.ndim > 0 or 'int' not in i.dtype)):
mapping[i] = as_gpuarray_variable(i, target)
else: else:
mapping[i] = i mapping[i] = i
for i in fgraph.variables: for i in fgraph.variables:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论