提交 659b7c8f authored 作者: João Victor Risso's avatar João Victor Risso

Add type and dimensions check to dgrid in GpuDnnTransformerGradT Op

上级 cac8567b
...@@ -3018,7 +3018,12 @@ class GpuDnnTransformerGradT(DnnBase): ...@@ -3018,7 +3018,12 @@ class GpuDnnTransformerGradT(DnnBase):
DnnBase.__init__(self, ["c_code/dnn_sptf_gt.c"], "APPLY_SPECIFIC(dnn_sptf_gt)") DnnBase.__init__(self, ["c_code/dnn_sptf_gt.c"], "APPLY_SPECIFIC(dnn_sptf_gt)")
def make_node(self, dgrid, desc): def make_node(self, dgrid, desc):
context_name = infer_context_name(dgrid) context_name = infer_context_name(desc)
dgrid = as_gpuarray_variable(dgrid, context_name)
assert dgrid.dtype in ('float16', 'float32', 'float64')
assert dgrid.ndim == 4
dtheta = GpuArrayType(dtype=dgrid.dtype, dtheta = GpuArrayType(dtype=dgrid.dtype,
broadcastable=(dgrid.type.ndim - 1) * (False,), broadcastable=(dgrid.type.ndim - 1) * (False,),
context_name=context_name)() context_name=context_name)()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论