Change dimg and dgrid setup to use img and grid's types on spatialtf grad

上级 2fbf72fa
...@@ -2985,12 +2985,8 @@ class GpuDnnTransformerGradI(DnnBase): ...@@ -2985,12 +2985,8 @@ class GpuDnnTransformerGradI(DnnBase):
if dy.ndim != 4: if dy.ndim != 4:
raise TypeError('dy must have 4 dimensions.') raise TypeError('dy must have 4 dimensions.')
dimg = GpuArrayType(dtype=img.dtype, dimg = img.type()
broadcastable=img.type.ndim * (False,), dgrid = grid.type()
context_name=context_name)()
dgrid = GpuArrayType(dtype=img.dtype,
broadcastable=img.type.ndim * (False,),
context_name=context_name)()
inputs = [img, theta, grid, grid_dims, dy, desc] inputs = [img, theta, grid, grid_dims, dy, desc]
outputs = [dimg, dgrid] outputs = [dimg, dgrid]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论