Fix check on dy's number of dimensions

上级 9181d894
...@@ -2982,8 +2982,8 @@ class GpuDnnTransformerGradI(DnnBase): ...@@ -2982,8 +2982,8 @@ class GpuDnnTransformerGradI(DnnBase):
grid_dims = as_tensor_variable(desc.owner.inputs[0]) grid_dims = as_tensor_variable(desc.owner.inputs[0])
dy = as_gpuarray_variable(dy, context_name) dy = as_gpuarray_variable(dy, context_name)
if img.ndim != 4: if dy.ndim != 4:
raise TypeError('img must have 4 dimensions.') raise TypeError('dy must have 4 dimensions.')
dimg = GpuArrayType(dtype=img.dtype, dimg = GpuArrayType(dtype=img.dtype,
broadcastable=img.type.ndim * (False,), broadcastable=img.type.ndim * (False,),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论