提交 33c97605 authored 作者: Frédéric Bastien

Merge pull request #3627 from sieben/simplify_comparaison

Simplify comparisons
...@@ -105,7 +105,7 @@ if ((err = cudnnCreate(&_handle)) != CUDNN_STATUS_SUCCESS) { ...@@ -105,7 +105,7 @@ if ((err = cudnnCreate(&_handle)) != CUDNN_STATUS_SUCCESS) {
" This isn't supported anymore." " This isn't supported anymore."
" Update to CuDNN v2 final version.") " Update to CuDNN v2 final version.")
raise RuntimeError(dnn_available.msg) raise RuntimeError(dnn_available.msg)
if v[0] >= 3000 and v[0] < 3007: if 3000 <= v[0] < 3007:
# 3007 is the final release of cudnn v3 # 3007 is the final release of cudnn v3
dnn_available.avail = False dnn_available.avail = False
dnn_available.msg = ( dnn_available.msg = (
......
...@@ -74,7 +74,7 @@ def _dnn_check_version(): ...@@ -74,7 +74,7 @@ def _dnn_check_version():
"You have an old release of CuDNN (or a release candidate) " "You have an old release of CuDNN (or a release candidate) "
"that isn't supported. Please update to at least v2 final " "that isn't supported. Please update to at least v2 final "
"version.") "version.")
if v >= 3000 and v < 3007: if 3000 <= v < 3007:
return False, ( return False, (
"You have installed a release candidate of CuDNN v3. This " "You have installed a release candidate of CuDNN v3. This "
"isn't supported. Please update to v3 final version.") "isn't supported. Please update to v3 final version.")
......
...@@ -1444,8 +1444,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -1444,8 +1444,7 @@ class ScanSaveMem(gof.Optimizer):
last_sitsot_idx = (node.op.n_mit_mot + last_sitsot_idx = (node.op.n_mit_mot +
node.op.n_mit_sot + node.op.n_mit_sot +
node.op.n_sit_sot - 1) node.op.n_sit_sot - 1)
preallocable_output = (i >= first_mitsot_idx and preallocable_output = (first_mitsot_idx <= i <= last_sitsot_idx)
i <= last_sitsot_idx)
if (prealloc_outs and preallocable_output): if (prealloc_outs and preallocable_output):
pval = select_max(nw_steps - start + init_l[i], pval = select_max(nw_steps - start + init_l[i],
......
...@@ -117,7 +117,7 @@ def sparse_random_inputs(format, shape, n=1, out_dtype=None, p=0.5, gap=None, ...@@ -117,7 +117,7 @@ def sparse_random_inputs(format, shape, n=1, out_dtype=None, p=0.5, gap=None,
if out_dtype is None: if out_dtype is None:
out_dtype = theano.config.floatX out_dtype = theano.config.floatX
assert 0 <= p and p <= 1 assert 0 <= p <= 1
assert len(shape) == 2 assert len(shape) == 2
assert out_dtype in sparse.all_dtypes assert out_dtype in sparse.all_dtypes
assert gap is None or isinstance(gap, (tuple, list)) assert gap is None or isinstance(gap, (tuple, list))
......
...@@ -432,7 +432,7 @@ def constant(x, name=None, ndim=None, dtype=None): ...@@ -432,7 +432,7 @@ def constant(x, name=None, ndim=None, dtype=None):
return ret return ret
sig = ret.signature() sig = ret.signature()
if (sig not in constant_cache and ret.data.size == 1 and if (sig not in constant_cache and ret.data.size == 1 and
ret.data <= 10 and ret.data >= -10 and (-10) <= ret.data <= 10 and
(ret.dtype in int_dtypes or ret.dtype in uint_dtypes or (ret.dtype in int_dtypes or ret.dtype in uint_dtypes or
(ret.dtype in float_dtypes and int(ret.data) == ret.data))): (ret.dtype in float_dtypes and int(ret.data) == ret.data))):
constant_cache[sig] = ret constant_cache[sig] = ret
......
...@@ -527,7 +527,7 @@ class Elemwise(OpenMPOp): ...@@ -527,7 +527,7 @@ class Elemwise(OpenMPOp):
self.nfunc = None self.nfunc = None
if getattr(self, 'nfunc_spec', None): if getattr(self, 'nfunc_spec', None):
self.nfunc = getattr(numpy, self.nfunc_spec[0]) self.nfunc = getattr(numpy, self.nfunc_spec[0])
elif self.scalar_op.nin > 0 and self.scalar_op.nin < 32: elif 0 < self.scalar_op.nin < 32:
self.ufunc = numpy.frompyfunc(self.scalar_op.impl, self.ufunc = numpy.frompyfunc(self.scalar_op.impl,
self.scalar_op.nin, self.scalar_op.nin,
self.scalar_op.nout) self.scalar_op.nout)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论