提交 19911b2e authored 作者: João Victor Risso's avatar João Victor Risso

Add warning in spatial transformer gradients when using lower precision floating-point types

上级 781d87bd
......@@ -2937,6 +2937,9 @@ class GpuDnnTransformerGradI(DnnBase):
def __init__(self):
DnnBase.__init__(self, ["c_code/dnn_sptf_gi.c"], "APPLY_SPECIFIC(dnn_sptf_gi)")
if theano.config.floatX == 'float16' or theano.config.floatX == 'float32':
warnings.warn(('GpuDnnTransformerGradI: computing gradients with float16 or '
'float32 might produce incorrect results due to lower precision.'))
def make_node(self, img, grid, dy, desc):
context_name = infer_context_name(img, grid, dy, desc)
......@@ -2982,6 +2985,9 @@ class GpuDnnTransformerGradT(DnnBase):
def __init__(self):
DnnBase.__init__(self, ["c_code/dnn_sptf_gt.c"], "APPLY_SPECIFIC(dnn_sptf_gt)")
if theano.config.floatX == 'float16' or theano.config.floatX == 'float32':
warnings.warn(('GpuDnnTransformerGradT: computing gradients with float16 or '
'float32 might produce incorrect results due to lower precision.'))
def make_node(self, dgrid, desc):
context_name = infer_context_name(desc)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论