提交 11bc4593 authored 作者: carriepl's avatar carriepl

Fix not using context_name in local_gpu_elemwise

上级 9d6615eb
...@@ -361,9 +361,9 @@ def local_gpu_elemwise(node, context_name): ...@@ -361,9 +361,9 @@ def local_gpu_elemwise(node, context_name):
for inp in node.inputs: for inp in node.inputs:
if inp.dtype != out_dtype: if inp.dtype != out_dtype:
gpu_cast_op = GpuElemwise(Cast(Scalar(out_dtype))) gpu_cast_op = GpuElemwise(Cast(Scalar(out_dtype)))
new_inputs.append(gpu_cast_op(as_gpuarray_variable(inp))) new_inputs.append(gpu_cast_op(as_gpuarray_variable(inp, context_name)))
else: else:
new_inputs.append(as_gpuarray_variable(inp)) new_inputs.append(as_gpuarray_variable(inp, context_name))
# Perform the exponent on the gpu and transfer the output back to the # Perform the exponent on the gpu and transfer the output back to the
# cpu. # cpu.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论