提交 afb76ace authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #5505 from abergeron/small_fixes

Mixed
...@@ -331,7 +331,8 @@ class NanGuardMode(Mode): ...@@ -331,7 +331,8 @@ class NanGuardMode(Mode):
def nan_check(node, thunk, storage_map, compute_map): def nan_check(node, thunk, storage_map, compute_map):
for var in node.outputs: for var in node.outputs:
if getattr(var.tag, 'nan_guard_mode_check', True): if (compute_map[var][0] and
getattr(var.tag, 'nan_guard_mode_check', True)):
do_check_on(storage_map[var][0], node) do_check_on(storage_map[var][0], node)
def nan_check_input(var, value): def nan_check_input(var, value):
......
...@@ -1170,6 +1170,7 @@ AddConfigVar('cmodule.age_thresh_use', ...@@ -1170,6 +1170,7 @@ AddConfigVar('cmodule.age_thresh_use',
def default_blas_ldflags(): def default_blas_ldflags():
global numpy global numpy
warn_record = []
try: try:
if (hasattr(numpy.distutils, '__config__') and if (hasattr(numpy.distutils, '__config__') and
numpy.distutils.__config__): numpy.distutils.__config__):
...@@ -1284,7 +1285,7 @@ def default_blas_ldflags(): ...@@ -1284,7 +1285,7 @@ def default_blas_ldflags():
import mkl # noqa import mkl # noqa
except ImportError as e: except ImportError as e:
if any([m for m in ('conda', 'Continuum') if m in sys.version]): if any([m for m in ('conda', 'Continuum') if m in sys.version]):
_logger.warning('install mkl with `conda install mkl-service`: %s', e) warn_record.append(('install mkl with `conda install mkl-service`: %s', e))
else: else:
# This branch is executed if no exception was raised # This branch is executed if no exception was raised
if sys.platform == "win32": if sys.platform == "win32":
...@@ -1327,6 +1328,13 @@ def default_blas_ldflags(): ...@@ -1327,6 +1328,13 @@ def default_blas_ldflags():
if res: if res:
return res return res
# If we are using conda and can't reuse numpy blas, then doing
# the fallback and test -lblas could give slow computation, so
# warn about this.
for warn in warn_record:
_logger.warning(*warn)
del warn_record
# Some environment don't have the lib dir in LD_LIBRARY_PATH. # Some environment don't have the lib dir in LD_LIBRARY_PATH.
# So add it. # So add it.
ret.extend(['-Wl,-rpath,' + l for l in ret.extend(['-Wl,-rpath,' + l for l in
......
...@@ -1093,6 +1093,9 @@ class T_subtensor(theano.tensor.tests.test_subtensor.T_subtensor): ...@@ -1093,6 +1093,9 @@ class T_subtensor(theano.tensor.tests.test_subtensor.T_subtensor):
self.assertTrue(val.ndim == data.ndim) self.assertTrue(val.ndim == data.ndim)
utt.assert_allclose(val, good) utt.assert_allclose(val, good)
def test_noncontiguous_idx(self):
    """Override of the base-class test of the same name.

    The base suite's non-contiguous advanced-indexing test is
    deliberately skipped for this backend (the original message says it
    "doesn't work here"; the underlying reason is not visible in this
    chunk — TODO confirm against the backend's AdvancedSubtensor1
    implementation).
    """
    raise SkipTest("test doesn't work here")
def test_advinc_subtensor1(): def test_advinc_subtensor1():
""" Test the second case in the opt local_gpu_advanced_incsubtensor1 """ """ Test the second case in the opt local_gpu_advanced_incsubtensor1 """
......
...@@ -529,6 +529,15 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin): ...@@ -529,6 +529,15 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin):
utt.verify_grad(lambda m: m[idx], utt.verify_grad(lambda m: m[idx],
[data]) [data])
def test_noncontiguous_idx(self):
    """Check AdvancedSubtensor1 with a non-contiguous index vector.

    Slicing a shared index vector with ``[::2]`` produces a strided
    (non-contiguous) integer index; the resulting graph must still use
    ``AdvancedSubtensor1`` and evaluate to the same values as NumPy
    fancy indexing with the equivalent index list.
    """
    data = rand(4, 2, 3)
    idx = [2, 2, 0, 0, 1, 1]
    mat = self.shared(data)
    # Every other element of the shared index vector -> strided view.
    strided_idx = self.shared(numpy.asarray(idx))[::2]
    out = mat[strided_idx]
    self.assertTrue(isinstance(out.owner.op, tensor.AdvancedSubtensor1))
    result = self.eval_output_and_check(out, op_type=self.adv_sub1, length=2)
    utt.assert_allclose(data[idx[::2]], result)
def test_err_invalid_list(self): def test_err_invalid_list(self):
n = self.shared(numpy.asarray(5, dtype=self.dtype)) n = self.shared(numpy.asarray(5, dtype=self.dtype))
self.assertRaises(TypeError, n.__getitem__, [0, 0]) self.assertRaises(TypeError, n.__getitem__, [0, 0])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论