提交 db7bffb4 authored 作者: abergeron's avatar abergeron

Merge pull request #2056 from nouiz/Tanjay94-TV

warn float64
......@@ -184,6 +184,15 @@ import theano and print the config variable, as in:
and similar functions. It also sets the default theano bit width for
arguments passed as Python floating-point numbers.
.. attribute:: warn_float64
String value: either 'ignore', 'warn', 'raise' or 'pdb'
Default: 'float64'
When creating a TensorVariable with dtype float64, what should be done?
This is useful to help find upcast to float64 in user code.
.. attribute:: allow_gc
Bool value: either ``True`` or ``False``
......
......@@ -24,6 +24,14 @@ AddConfigVar('floatX',
EnumStr('float64', 'float32', convert=floatX_convert,),
)
AddConfigVar('warn_float64',
"Do an action when a tensor variable with float64 dtype is"
" created. They can't be run on the GPU with the current(old)"
" gpu back-end and are slow with gamer GPUs.",
EnumStr('ignore', 'warn', 'raise', 'pdb'),
in_c_key=False,
)
AddConfigVar('cast_policy',
"Rules for implicit type casting",
EnumStr('custom', 'numpy+floatX',
......
import copy
import pdb
import sys
import traceback as tb
import warnings
import numpy
......@@ -9,6 +13,7 @@ from theano.gof import Constant, Variable
from theano.gof.utils import hashtype
from theano.tensor.utils import hash_from_ndarray
from theano.tensor.type import TensorType
from theano.configparser import config
class AsTensorError(TypeError):
......@@ -574,6 +579,34 @@ class _tensor_py_operators:
class TensorVariable(_tensor_py_operators, Variable):
"""Subclass to add the tensor operators to the basic `Variable` class."""
def __init__(self, type, owner=None, index=None, name=None):
super(TensorVariable, self).__init__(type, owner=owner,
index=index, name=name)
if (config.warn_float64 != 'ignore' and type.dtype == 'float64'):
msg = ('You are creating a TensorVariable '
'with float64 dtype. You requested an action via '
'the Theano flag warn_float64={ignore,warn,raise,pdb}.')
if config.warn_float64 == "warn":
# Get the user stack. We don't want function inside the
# tensor and gof directory to be shown to the user.
x = tb.extract_stack()
nb_rm = 0
while x:
file_path = x[-1][0]
rm = False
for p in ["theano/tensor/",
"theano/gof/"]:
if p in file_path:
x = x[:-1]
nb_rm += 1
rm = True
if not rm:
break
warnings.warn(msg, stacklevel=1 + nb_rm)
elif config.warn_float64 == "raise":
raise Exception(msg)
elif config.warn_float64 == 'pdb':
import pdb;pdb.set_trace()
TensorType.Variable = TensorVariable
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论