提交 7200fdfd authored 作者: Frederic Bastien's avatar Frederic Bastien

added infer_shape() to HostFromGpu and GpuFromHost

上级 adb52865
......@@ -44,6 +44,8 @@ class HostFromGpu(Op):
def grad(self, inputs, (gz,)):
return gz,
#return [GpuFromHost()(gz)]
def infer_shape(self, node, xshp):
return xshp
host_from_gpu = HostFromGpu()
class GpuFromHost(Op):
......@@ -62,6 +64,8 @@ class GpuFromHost(Op):
def grad(self, inputs, (gz,)):
return gz,
#return [HostFromGpu()(gz)]
def infer_shape(self, node, xshp):
return xshp
gpu_from_host = GpuFromHost()
class GpuElemwise(Op):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论