提交 1751f789 authored 作者: abergeron's avatar abergeron

Merge pull request #2519 from bartvm/values_eq_approx_move

Move values_eq_approx outside of function to allow pickling
......@@ -1195,6 +1195,23 @@ def local_conv_fft_full(node):
return
def values_eq_approx_high_tol(a, b):
"""This fct is needed to don't have DebugMode raise useless
error due to ronding error.
This happen as We reduce on the two last dimensions, so this
can raise the absolute error if the number of element we
reduce on is significant.
"""
assert a.ndim == 4
atol = None
if a.shape[-1] * a.shape[-2] > 100:
# For float32 the default atol is 1e-5
atol = 3e-5
return CudaNdarrayType.values_eq_approx(a, b, atol=atol)
@local_optimizer([gpu_from_host, conv.ConvOp])
def local_gpu_conv(node):
"""
......@@ -1244,22 +1261,6 @@ def local_gpu_conv(node):
return make_graph
return ret
def values_eq_approx(a, b):
"""This fct is needed to don't have DebugMode raise useless
error due to ronding error.
This happen as We reduce on the two last dimensions, so this
can raise the absolute error if the number of element we
reduce on is significant.
"""
assert a.ndim == 4
atol = None
if a.shape[-1] * a.shape[-2] > 100:
#For float32 the default atol is 1e-5
atol = 3e-5
return CudaNdarrayType.values_eq_approx(a, b, atol=atol)
if isinstance(node.op, GpuFromHost):
#gpu_from_host(conv) -> gpu_conv(gpu_from_host)
host_input = node.inputs[0]
......@@ -1272,7 +1273,7 @@ def local_gpu_conv(node):
gpu_from_host(kern))
out = tensor.patternbroadcast(out,
node.outputs[0].broadcastable)
out.values_eq_approx = values_eq_approx
out.values_eq_approx = values_eq_approx_high_tol
# in some case the ConvOp broadcast the last 2 dimensions
# differently then the gpu ConvOp
return [out]
......@@ -1291,7 +1292,7 @@ def local_gpu_conv(node):
out = tensor.patternbroadcast(
host_from_gpu(out),
node.outputs[0].broadcastable)
out.values_eq_approx = values_eq_approx
out.values_eq_approx = values_eq_approx_high_tol
# in some case the ConvOp broadcast the last 2 dimensions
# differently then the gpu ConvOp
return [out]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论