提交 633d89f2 authored 作者: Bart van Merrienboer's avatar Bart van Merrienboer

Move values_eq_approx outside of function to allow pickling

上级 8fc19902
......@@ -1195,6 +1195,23 @@ def local_conv_fft_full(node):
return
def values_eq_approx_high_tol(a, b):
"""This fct is needed to don't have DebugMode raise useless
error due to ronding error.
This happen as We reduce on the two last dimensions, so this
can raise the absolute error if the number of element we
reduce on is significant.
"""
assert a.ndim == 4
atol = None
if a.shape[-1] * a.shape[-2] > 100:
# For float32 the default atol is 1e-5
atol = 3e-5
return CudaNdarrayType.values_eq_approx(a, b, atol=atol)
@local_optimizer([gpu_from_host, conv.ConvOp])
def local_gpu_conv(node):
"""
......@@ -1244,22 +1261,6 @@ def local_gpu_conv(node):
return make_graph
return ret
def values_eq_approx(a, b):
"""This fct is needed to don't have DebugMode raise useless
error due to ronding error.
This happen as We reduce on the two last dimensions, so this
can raise the absolute error if the number of element we
reduce on is significant.
"""
assert a.ndim == 4
atol = None
if a.shape[-1] * a.shape[-2] > 100:
#For float32 the default atol is 1e-5
atol = 3e-5
return CudaNdarrayType.values_eq_approx(a, b, atol=atol)
if isinstance(node.op, GpuFromHost):
#gpu_from_host(conv) -> gpu_conv(gpu_from_host)
host_input = node.inputs[0]
......@@ -1272,7 +1273,7 @@ def local_gpu_conv(node):
gpu_from_host(kern))
out = tensor.patternbroadcast(out,
node.outputs[0].broadcastable)
out.values_eq_approx = values_eq_approx
out.values_eq_approx = values_eq_approx_high_tol
# in some case the ConvOp broadcast the last 2 dimensions
# differently then the gpu ConvOp
return [out]
......@@ -1291,7 +1292,7 @@ def local_gpu_conv(node):
out = tensor.patternbroadcast(
host_from_gpu(out),
node.outputs[0].broadcastable)
out.values_eq_approx = values_eq_approx
out.values_eq_approx = values_eq_approx_high_tol
# in some case the ConvOp broadcast the last 2 dimensions
# differently then the gpu ConvOp
return [out]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论