Change dimg and dgrid setup to use img and grid's types on spatialtf grad

上级 2fbf72fa
......@@ -2985,12 +2985,8 @@ class GpuDnnTransformerGradI(DnnBase):
if dy.ndim != 4:
raise TypeError('dy must have 4 dimensions.')
dimg = GpuArrayType(dtype=img.dtype,
broadcastable=img.type.ndim * (False,),
context_name=context_name)()
dgrid = GpuArrayType(dtype=img.dtype,
broadcastable=img.type.ndim * (False,),
context_name=context_name)()
dimg = img.type()
dgrid = grid.type()
inputs = [img, theta, grid, grid_dims, dy, desc]
outputs = [dimg, dgrid]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论