提交 11bc4593 authored 作者: carriepl's avatar carriepl

Fix not using context_name in local_gpu_elemwise

上级 9d6615eb
......@@ -361,9 +361,9 @@ def local_gpu_elemwise(node, context_name):
for inp in node.inputs:
if inp.dtype != out_dtype:
gpu_cast_op = GpuElemwise(Cast(Scalar(out_dtype)))
new_inputs.append(gpu_cast_op(as_gpuarray_variable(inp)))
new_inputs.append(gpu_cast_op(as_gpuarray_variable(inp, context_name)))
else:
new_inputs.append(as_gpuarray_variable(inp))
new_inputs.append(as_gpuarray_variable(inp, context_name))
# Perform the exponent on the gpu and transfer the output back to the
# cpu.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论