提交 5cd4e44e authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add support for float16 for stability opts.

上级 f3e5604b
......@@ -5140,7 +5140,8 @@ def local_log_erfc(node):
T.log(1 - 1 / (2 * x ** 2) + 3 / (4 * x ** 4)
- 15 / (8 * x ** 6)))
if node.outputs[0].dtype == 'float32':
if (node.outputs[0].dtype == 'float32' or
node.outputs[0].dtype == 'float16'):
threshold = 10.0541949
elif node.outputs[0].dtype == 'float64':
threshold = 26.641747557
......@@ -5287,7 +5288,7 @@ def local_grad_log_erfc_neg(node):
3 / (4 * (x ** 4)) - 15 / (8 * (x ** 6)), -1)
* T.cast(T.sqrt(numpy.pi), dtype=x.dtype))
if x.dtype == 'float32':
if x.dtype == 'float32' or x.dtype == 'float16':
threshold = 9.3
#threshold = 10.1
elif x.dtype == 'float64':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论