提交 98d7817f authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Move ref_cast to config.py.

上级 e3e0e6d4
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
import theano.tensor
import theano.gpuarray import theano.gpuarray
if theano.gpuarray.pygpu is None: if theano.gpuarray.pygpu is None:
...@@ -21,3 +22,10 @@ if theano.config.mode == 'FAST_COMPILE': ...@@ -21,3 +22,10 @@ if theano.config.mode == 'FAST_COMPILE':
else: else:
mode_with_gpu = theano.compile.mode.get_default_mode().including('gpuarray').excluding('gpu') mode_with_gpu = theano.compile.mode.get_default_mode().including('gpuarray').excluding('gpu')
mode_without_gpu = theano.compile.mode.get_default_mode().excluding('gpuarray') mode_without_gpu = theano.compile.mode.get_default_mode().excluding('gpuarray')
# If using float16, cast reference input to float32
def ref_cast(x):
if x.type.dtype == 'float16':
x = theano.tensor.cast(x, 'float32')
return x
...@@ -17,7 +17,7 @@ from .. import dnn ...@@ -17,7 +17,7 @@ from .. import dnn
from ..basic_ops import GpuAllocEmpty from ..basic_ops import GpuAllocEmpty
from ..type import gpuarray_shared_constructor from ..type import gpuarray_shared_constructor
from .config import mode_with_gpu, mode_without_gpu, test_ctx_name from .config import mode_with_gpu, mode_without_gpu, test_ctx_name, ref_cast
from . import test_nnet from . import test_nnet
from .rnn_support import Model, GRU, LSTM, WrapperLayer from .rnn_support import Model, GRU, LSTM, WrapperLayer
...@@ -33,13 +33,6 @@ def set_precision(floatX): ...@@ -33,13 +33,6 @@ def set_precision(floatX):
return precision return precision
# If using float16, cast reference input to float32
def ref_cast(x):
if theano.config.floatX == 'float16':
x = T.cast(x, 'float32')
return x
def test_dnn_conv_desc_merge(): def test_dnn_conv_desc_merge():
if not dnn.dnn_available(test_ctx_name): if not dnn.dnn_available(test_ctx_name):
raise SkipTest(dnn.dnn_available.msg) raise SkipTest(dnn.dnn_available.msg)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论