提交 e20f07e6 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add back support for tensors in as_gpuarray_variable().

上级 f332373c
...@@ -51,9 +51,15 @@ def as_gpuarray_variable(x, context_name): ...@@ -51,9 +51,15 @@ def as_gpuarray_variable(x, context_name):
# the rest of the body # the rest of the body
break break
# If we couldn't deal with transfers, then maybe it's a tensor
if isinstance(x.type, tensor.TensorType):
return GpuFromHost(context_name)(x)
# Try _as_GpuArrayVariable if possible
if hasattr(x, '_as_GpuArrayVariable'): if hasattr(x, '_as_GpuArrayVariable'):
return x._as_GpuArrayVariable(context_name) return x._as_GpuArrayVariable(context_name)
# If it didn't work try for a constant
ctx = get_context(context_name) ctx = get_context(context_name)
if isinstance(x, gpuarray.GpuArray): if isinstance(x, gpuarray.GpuArray):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论