提交 f5abe95e authored 作者: João Victor Risso's avatar João Victor Risso

Add checks for dimensions' type and ndims in GpuDnnTransformerDesc

上级 efb57395
......@@ -2869,6 +2869,10 @@ class GpuDnnTransformerDesc(COp):
def make_node(self, dimensions):
dimensions = as_tensor_variable(dimensions)
assert dimensions.dtype in theano.tensor.basic.int_dtypes
assert dimensions.ndim == 1
dimensions = theano.tensor.basic.cast(dimensions, 'int64')
node = Apply(self, [dimensions],
[CDataType("cudnnSpatialTransformerDescriptor_t",
freefunc="cudnnDestroySpatialTransformerDescriptor")()])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论