提交 6328000d authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Apply same workaround for other descriptors

上级 045eff2e
......@@ -1322,9 +1322,16 @@ class GpuDnnPoolDesc(GpuOp):
if self.pad != (0, 0) and version() == -1:
raise RuntimeError("CuDNN pooling with padding requires CuDNN v2")
return Apply(self, [],
node = Apply(self, [],
[CDataType("cudnnPoolingDescriptor_t",
freefunc="cudnnDestroyPoolingDescriptor")()])
# DebugMode cannot compare the values of CDataType variables, so by
# default it returns False all the time. To prevent DebugMode from
# complaining because of the MergeOptimizer, we make this variable
# always compare to True.
out = node.outputs[0]
out.tag.values_eq_approx = tensor.type.values_eq_approx_always_true
return node
def c_code(self, node, name, inputs, outputs, sub):
desc, = outputs
......
......@@ -307,9 +307,16 @@ class GpuDnnConvDesc(COp):
if kern_shape.type.ndim != 1 or kern_shape.type.dtype != 'int64':
raise TypeError('kern must be 1D shape tensor')
return Apply(self, [kern_shape],
node = Apply(self, [kern_shape],
[CDataType("cudnnConvolutionDescriptor_t",
freefunc="cudnnDestroyConvolutionDescriptor")()])
# DebugMode cannot compare the values of CDataType variables, so by
# default it returns False all the time. To prevent DebugMode from
# complaining because of the MergeOptimizer, we make this variable
# always compare to True.
out = node.outputs[0]
out.tag.values_eq_approx = tensor.type.values_eq_approx_always_true
return node
def get_op_params(self):
pad0 = '0'
......@@ -998,9 +1005,16 @@ class GpuDnnPoolDesc(Op):
self.pad = (0, 0)
def make_node(self):
return Apply(self, [],
node = Apply(self, [],
[CDataType("cudnnPoolingDescriptor_t",
freefunc="cudnnDestroyPoolingDescriptor")()])
# DebugMode cannot compare the values of CDataType variables, so by
# default it returns False all the time. To prevent DebugMode from
# complaining because of the MergeOptimizer, we make this variable
# always compare to True.
out = node.outputs[0]
out.tag.values_eq_approx = tensor.type.values_eq_approx_always_true
return node
def c_code(self, node, name, inputs, outputs, sub):
desc, = outputs
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论