提交 69c71a3f authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Better handle stuff on two devices.

上级 a4fedfed
......@@ -43,6 +43,9 @@ def as_gpuarray_variable(x, context_name):
if isinstance(x.owner.op, GpuFromHost):
x = x.owner.inputs[0]
continue
if isinstance(x.owner.op, GpuToGpu):
x = x.owner.inputs[0]
continue
# If none of the conditions where met, then continue with
# the rest of the body
......@@ -51,8 +54,18 @@ def as_gpuarray_variable(x, context_name):
if hasattr(x, '_as_GpuArrayVariable'):
return x._as_GpuArrayVariable(context_name)
tensor_x = as_tensor_variable(x)
return GpuFromHost(context_name)(tensor_x)
ctx = get_context(context_name)
if isinstance(x, gpuarray.GpuArray):
if x.context.ptr != ctx.ptr:
x = x.transfer(ctx)
x = gpuarray.asarray(x, context=ctx)
bcast = [(s == 1) for s in x.shape]
return GpuArrayConstant(GpuArrayType(dtype=x.dtype,
broadcastable=bcast,
context_name=context_name))
def infer_context_name(*vars):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论