提交 5f3c1916 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix implementations of R_op().

上级 bf828045
...@@ -333,10 +333,7 @@ class HostFromGpu(Op): ...@@ -333,10 +333,7 @@ class HostFromGpu(Op):
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
ev, = eval_points ev, = eval_points
if isinstance(ev, tensor.TensorType): return self(ev)
return [GpuFromHost(inputs[0].type.context_name)(ev)]
else:
return [ev]
def infer_shape(self, node, xshp): def infer_shape(self, node, xshp):
return xshp return xshp
...@@ -377,10 +374,7 @@ class GpuFromHost(Op): ...@@ -377,10 +374,7 @@ class GpuFromHost(Op):
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
ev, = eval_points ev, = eval_points
if isinstance(ev, GpuArrayType): return self(ev)
return [host_from_gpu(ev)]
else:
return [ev]
def infer_shape(self, node, xshp): def infer_shape(self, node, xshp):
return xshp return xshp
...@@ -441,7 +435,7 @@ class GpuToGpu(Op): ...@@ -441,7 +435,7 @@ class GpuToGpu(Op):
return [GpuToGpu(inputs[0].type.context_name)(gz)] return [GpuToGpu(inputs[0].type.context_name)(gz)]
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
return self.grad(inputs, eval_points) return self(eval_points[0])
def infer_shape(self, node, xshp): def infer_shape(self, node, xshp):
return xshp return xshp
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论