提交 d07ce6a9 authored 作者: Frederic's avatar Frederic

Allow warn_float64 to be ignore, warn, raise or pdb.

上级 e42c3a8c
...@@ -28,7 +28,7 @@ AddConfigVar('warn_float64', ...@@ -28,7 +28,7 @@ AddConfigVar('warn_float64',
"If True, warn when a tensor variable with float64 dtype is" "If True, warn when a tensor variable with float64 dtype is"
" created. They can't be run on the GPU with the current(old)" " created. They can't be run on the GPU with the current(old)"
" gpu back-end and are slow with gamer GPUs.", " gpu back-end and are slow with gamer GPUs.",
BoolParam(False), EnumStr('ignore', 'warn', 'raise', 'pdb'),
in_c_key=False, in_c_key=False,
) )
......
import copy import copy
import pdb
import sys
import traceback as tb
import warnings import warnings
import numpy import numpy
...@@ -579,28 +582,31 @@ class TensorVariable(_tensor_py_operators, Variable): ...@@ -579,28 +582,31 @@ class TensorVariable(_tensor_py_operators, Variable):
def __init__(self, type, owner=None, index=None, name=None): def __init__(self, type, owner=None, index=None, name=None):
super(TensorVariable, self).__init__(type, owner=owner, super(TensorVariable, self).__init__(type, owner=owner,
index=index, name=name) index=index, name=name)
if (config.warn_float64 and type.dtype == 'float64'): if (config.warn_float64 != 'ignore' and type.dtype == 'float64'):
# Get the user stack. We don't want function inside the msg = ('You are creating a TensorVariable '
# tensor and gof directory to be shown to the user. 'with float64 dtype. You requested this warning via '
import traceback as tb 'the Theano flag warn_float64=True.')
x = tb.extract_stack() if config.warn_float64 == "warn":
nb_rm = 0 # Get the user stack. We don't want function inside the
while x: # tensor and gof directory to be shown to the user.
file_path = x[-1][0] x = tb.extract_stack()
rm = False nb_rm = 0
for p in ["theano/tensor/", while x:
"theano/gof/"]: file_path = x[-1][0]
if p in file_path: rm = False
x = x[:-1] for p in ["theano/tensor/",
nb_rm += 1 "theano/gof/"]:
rm = True if p in file_path:
if not rm: x = x[:-1]
break nb_rm += 1
warnings.warn( rm = True
'Warning, you are creating a TensorVariable ' if not rm:
'with float64 dtype. You requested this warning via ' break
'the Theano flag warn_float64=True.', stacklevel=1 + nb_rm) warnings.warn(msg, stacklevel=1 + nb_rm)
elif config.warn_float64 == "raise":
raise Exception(msg)
elif config.warn_float64 == 'pdb':
import pdb;pdb.set_trace()
TensorType.Variable = TensorVariable TensorType.Variable = TensorVariable
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论