提交 1e68d76b authored 作者: Frederic Bastien's avatar Frederic Bastien

Prevent epsilon from upcasting the dtype when converting abstract_bn to a Theano graph.

上级 2cae7176
...@@ -627,6 +627,9 @@ def local_abstract_batch_norm_train(node): ...@@ -627,6 +627,9 @@ def local_abstract_batch_norm_train(node):
mean = x.mean(axes, keepdims=True) mean = x.mean(axes, keepdims=True)
var = x.var(axes, keepdims=True) var = x.var(axes, keepdims=True)
# The epsilon should not upcast the dtype.
if var.dtype == 'float32' and epsilon.dtype == 'float64':
epsilon = epsilon.astype('float32')
invstd = T.inv(T.sqrt(var + epsilon)) invstd = T.inv(T.sqrt(var + epsilon))
out = (x - mean) * (scale * invstd) + bias out = (x - mean) * (scale * invstd) + bias
results = [out, mean, invstd] results = [out, mean, invstd]
...@@ -702,6 +705,10 @@ def local_abstract_batch_norm_inference(node): ...@@ -702,6 +705,10 @@ def local_abstract_batch_norm_inference(node):
not isinstance(epsilon.type, TensorType): not isinstance(epsilon.type, TensorType):
return None return None
# The epsilon should not upcast the dtype.
if estimated_variance.dtype == 'float32' and epsilon.dtype == 'float64':
epsilon = epsilon.astype('float32')
result = (x - estimated_mean) * (scale / T.sqrt(estimated_variance + epsilon)) + bias result = (x - estimated_mean) * (scale / T.sqrt(estimated_variance + epsilon)) + bias
result = T.patternbroadcast(result, node.outputs[0].broadcastable) result = T.patternbroadcast(result, node.outputs[0].broadcastable)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论