提交 92e772f9 authored 作者: João Victor Risso's avatar João Victor Risso

Remove warning when using float32 and float16 in spatial transformer grad ops

上级 dda7258b
...@@ -2937,9 +2937,6 @@ class GpuDnnTransformerGradI(DnnBase): ...@@ -2937,9 +2937,6 @@ class GpuDnnTransformerGradI(DnnBase):
def __init__(self): def __init__(self):
DnnBase.__init__(self, ["c_code/dnn_sptf_gi.c"], "APPLY_SPECIFIC(dnn_sptf_gi)") DnnBase.__init__(self, ["c_code/dnn_sptf_gi.c"], "APPLY_SPECIFIC(dnn_sptf_gi)")
if theano.config.floatX == 'float16' or theano.config.floatX == 'float32':
warnings.warn(('GpuDnnTransformerGradI: computing gradients with float16 or '
'float32 might produce incorrect results due to lower precision.'))
def make_node(self, img, grid, dy, desc): def make_node(self, img, grid, dy, desc):
context_name = infer_context_name(img, grid, dy, desc) context_name = infer_context_name(img, grid, dy, desc)
...@@ -2985,9 +2982,6 @@ class GpuDnnTransformerGradT(DnnBase): ...@@ -2985,9 +2982,6 @@ class GpuDnnTransformerGradT(DnnBase):
def __init__(self): def __init__(self):
DnnBase.__init__(self, ["c_code/dnn_sptf_gt.c"], "APPLY_SPECIFIC(dnn_sptf_gt)") DnnBase.__init__(self, ["c_code/dnn_sptf_gt.c"], "APPLY_SPECIFIC(dnn_sptf_gt)")
if theano.config.floatX == 'float16' or theano.config.floatX == 'float32':
warnings.warn(('GpuDnnTransformerGradT: computing gradients with float16 or '
'float32 might produce incorrect results due to lower precision.'))
def make_node(self, dgrid, desc): def make_node(self, dgrid, desc):
context_name = infer_context_name(desc) context_name = infer_context_name(desc)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论