提交 b954ba93 authored 作者: João Victor Risso's avatar João Victor Risso

Check desc type in GpuDnnTransformerGradT make_node method

上级 6ee06010
......@@ -2950,6 +2950,10 @@ class GpuDnnTransformerGradT(DnnBase):
def make_node(self, dgrid, desc):
context_name = infer_context_name(desc)
if (not isinstance(desc.type, CDataType) or
desc.type.ctype != 'cudnnSpatialTransformerDescriptor_t'):
raise ValueError('desc must be cudnnSpatialTransformerDescriptor_t')
dgrid = as_gpuarray_variable(dgrid, context_name)
assert dgrid.dtype in ('float16', 'float32', 'float64')
assert dgrid.ndim == 4
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论