提交 b70039a8 authored 作者: Brandon T. Willard

Fix bugbears caught by flake8-bugbear and related test updates

Most of the bugbears involved `assert False` and bad function keyword defaults. In tests, the former often signaled needed `pytest.raises` rewrites.
上级 0a499b81
......@@ -87,7 +87,7 @@ if version_data["error"] is not None:
lines = [l for l in lines if l.startswith("FALLBACK_VERSION")]
assert len(lines) == 1
FALLBACK_VERSION = lines[0].split("=")[1].strip().strip('""')
FALLBACK_VERSION = lines[0].split("=")[1].strip().strip('"')
version_data["version"] = FALLBACK_VERSION
......
......@@ -9,8 +9,8 @@ from theano.tensor.nnet import sigmoid
class NNet(object):
def __init__(
self,
input=tensor.dvector("input"),
target=tensor.dvector("target"),
input=None,
target=None,
n_input=1,
n_hidden=1,
n_output=1,
......@@ -19,6 +19,11 @@ class NNet(object):
):
super(NNet, self).__init__(**kw)
if input is None:
input = tensor.dvector("input")
if target is None:
target = tensor.dvector("target")
self.input = input
self.target = target
self.lr = shared(lr, "learning_rate")
......@@ -46,21 +51,20 @@ class NNet(object):
self.output_from_hidden = pfunc([self.hidden], self.output)
class TestNnet:
def test_nnet(self):
rng = np.random.RandomState(1827)
data = rng.rand(10, 4)
nnet = NNet(n_input=3, n_hidden=10)
for epoch in range(3):
mean_cost = 0
for x in data:
input = x[0:3]
target = x[3:]
output, cost = nnet.sgd_step(input, target)
mean_cost += cost
mean_cost /= float(len(data))
# print 'Mean cost at epoch %s: %s' % (epoch, mean_cost)
assert abs(mean_cost - 0.20588975452) < 1e-6
# Just call functions to make sure they do not crash.
nnet.compute_output(input)
nnet.output_from_hidden(np.ones(10))
def test_nnet():
rng = np.random.RandomState(1827)
data = rng.rand(10, 4)
nnet = NNet(n_input=3, n_hidden=10)
for epoch in range(3):
mean_cost = 0
for x in data:
input = x[0:3]
target = x[3:]
output, cost = nnet.sgd_step(input, target)
mean_cost += cost
mean_cost /= float(len(data))
# print 'Mean cost at epoch %s: %s' % (epoch, mean_cost)
assert abs(mean_cost - 0.20588975452) < 1e-6
# Just call functions to make sure they do not crash.
nnet.compute_output(input)
nnet.output_from_hidden(np.ones(10))
......@@ -73,13 +73,10 @@ class TestPfunc:
# Test that shared variables cannot be used as function inputs.
w_init = np.random.rand(2, 2)
w = shared(w_init.copy(), "w")
try:
with pytest.raises(
TypeError, match=r"^Cannot use a shared variable \(w\) as explicit input"
):
pfunc([w], theano.tensor.sum(w * w))
assert False
except TypeError as e:
msg = "Cannot use a shared variable (w) as explicit input"
if str(e).find(msg) < 0:
raise
def test_default_container(self):
# Ensure it is possible to (implicitly) use a shared variable in a
......
......@@ -224,20 +224,19 @@ class TestComputeTestValue:
# Since we have to inspect the traceback,
# we cannot simply use self.assertRaises()
try:
with pytest.raises(ValueError):
theano.scan(fn=fx, outputs_info=tt.ones_like(A), non_sequences=A, n_steps=k)
assert False
except ValueError:
# Get traceback
tb = sys.exc_info()[2]
frame_infos = traceback.extract_tb(tb)
# We should be in the "fx" function defined above
expected = "test_compute_test_value.py"
assert any(
(os.path.split(frame_info[0])[1] == expected and frame_info[2] == "fx")
for frame_info in frame_infos
), frame_infos
# Get traceback
tb = sys.exc_info()[2]
frame_infos = traceback.extract_tb(tb)
# We should be in the "fx" function defined above
expected = "test_compute_test_value.py"
assert any(
(os.path.split(frame_info[0])[1] == expected and frame_info[2] == "fx")
for frame_info in frame_infos
), frame_infos
@theano.change_flags(compute_test_value="raise")
def test_scan_err2(self):
......@@ -258,13 +257,10 @@ class TestComputeTestValue:
# Since we have to inspect the traceback,
# we cannot simply use self.assertRaises()
try:
with pytest.raises(ValueError, match="^could not broadcast input"):
theano.scan(
fn=fx, outputs_info=tt.ones_like(A.T), non_sequences=A, n_steps=k
)
assert False
except ValueError as e:
assert str(e).startswith("could not broadcast input"), str(e)
@theano.change_flags(compute_test_value="raise")
def test_no_c_code(self):
......
......@@ -59,7 +59,7 @@ def check_dtype_config_support(dtype, precision):
try:
f()
except RuntimeError as e:
assert "CUDNN_STATUS_ARCH_MISMATCH" in e.message
assert "CUDNN_STATUS_ARCH_MISMATCH" in str(e)
return False
return True
......
import itertools
import numpy as np
import pytest
import theano
from theano import config
......@@ -223,18 +222,6 @@ class TestGpuSger(TestGer):
self.gemm = gpugemm_inplace
super().setup_method()
@pytest.mark.skip(reason="0-sized objects not supported")
def test_f32_0_0(self):
assert False
@pytest.mark.skip(reason="0-sized objects not supported")
def test_f32_1_0(self):
assert False
@pytest.mark.skip(reason="0-sized objects not supported")
def test_f32_0_1(self):
assert False
class TestGpuSgerNoTransfer(TestGpuSger):
shared = staticmethod(gpuarray_shared_constructor)
......
......@@ -21,10 +21,13 @@ def set_theano_flags():
def compare_jax_and_py(
fgraph,
inputs,
assert_fn=partial(np.testing.assert_allclose, rtol=1e-4),
assert_fn=None,
simplify=False,
must_be_device_array=True,
):
if assert_fn is None:
assert_fn = partial(np.testing.assert_allclose, rtol=1e-4)
if not simplify:
opts = theano.gof.Query(include=[None], exclude=["cxx_only", "BlasOpt"])
jax_mode = theano.compile.mode.Mode(theano.sandbox.jax_linker.JAXLinker(), opts)
......@@ -308,7 +311,7 @@ def test_jax_scan():
test_input_vals = [np.array(10.0).astype(tt.config.floatX)]
(jax_res,) = compare_jax_and_py(out_fg, test_input_vals)
assert False
raise AssertionError()
def test_jax_Subtensors():
......
......@@ -8,6 +8,8 @@ If you do want to rewrite these tests, bear in mind:
* You don't need to use Composite.
* FunctionGraph and DualLinker are old, use compile.function instead.
"""
import pytest
import numpy as np
import theano
......@@ -432,11 +434,8 @@ class TestComplexMod:
def test_fail(self):
x = complex64()
y = int32()
try:
with pytest.raises(ComplexError):
x % y
assert False
except ComplexError:
pass
class TestDiv:
......
......@@ -717,9 +717,14 @@ class TestAddMul:
def _testSS(
self,
op,
array1=np.array([[1.0, 0], [3, 0], [0, 6]]),
array2=np.asarray([[0, 2.0], [0, 4], [5, 0]]),
array1=None,
array2=None,
):
if array1 is None:
array1 = np.array([[1.0, 0], [3, 0], [0, 6]])
if array2 is None:
array2 = np.asarray([[0, 2.0], [0, 4], [5, 0]])
for mtype1, mtype2 in product(_mtypes, _mtypes):
for dtype1, dtype2 in [
("float64", "int8"),
......@@ -757,9 +762,14 @@ class TestAddMul:
def _testSD(
self,
op,
array1=np.array([[1.0, 0], [3, 0], [0, 6]]),
array2=np.asarray([[0, 2.0], [0, 4], [5, 0]]),
array1=None,
array2=None,
):
if array1 is None:
array1 = np.array([[1.0, 0], [3, 0], [0, 6]])
if array2 is None:
array2 = np.asarray([[0, 2.0], [0, 4], [5, 0]])
for mtype in _mtypes:
for a in [
np.array(array1),
......@@ -810,9 +820,14 @@ class TestAddMul:
def _testDS(
self,
op,
array1=np.array([[1.0, 0], [3, 0], [0, 6]]),
array2=np.asarray([[0, 2.0], [0, 4], [5, 0]]),
array1=None,
array2=None,
):
if array1 is None:
array1 = np.array([[1.0, 0], [3, 0], [0, 6]])
if array2 is None:
array2 = np.asarray([[0, 2.0], [0, 4], [5, 0]])
for mtype in _mtypes:
for b in [
np.asarray(array2),
......
......@@ -51,7 +51,7 @@ def exec_multilayer_conv_nnet_old(
nkerns,
unroll_batch=0,
unroll_kern=0,
img=tt.dmatrix(),
img=None,
validate=True,
conv_op_py=False,
do_print=True,
......@@ -60,6 +60,8 @@ def exec_multilayer_conv_nnet_old(
unroll_patch_size=False,
verbose=0,
):
if img is None:
img = tt.dmatrix()
# build actual input images
imgval = global_rng.rand(bsize, imshp[0], imshp[1], imshp[2])
......@@ -180,13 +182,15 @@ def exec_multilayer_conv_nnet(
nkerns,
unroll_batch=0,
unroll_kern=0,
img=tt.dmatrix(),
img=None,
do_print=True,
repeat=1,
unroll_patch=False,
unroll_patch_size=False,
verbose=0,
):
if img is None:
img = tt.dmatrix()
# build actual input images
imgval = global_rng.rand(bsize, imshp[0], imshp[1], imshp[2])
......
import itertools
import logging
import operator
import os
import sys
import warnings
import builtins
......@@ -17,13 +15,11 @@ from tempfile import mkstemp
from copy import copy, deepcopy
from functools import partial, reduce
from six.moves import StringIO
from numpy.testing import assert_array_equal, assert_allclose, assert_almost_equal
from theano import change_flags
from theano.compat import exc_message, operator_div
from theano.compat import operator_div
from theano import compile, config, function, gof, shared
from theano.compile import DeepCopyOp
from theano.compile.mode import get_default_mode
......@@ -3585,31 +3581,12 @@ class TestMaxAndArgmax:
def test_basic_2_invalid(self):
n = as_tensor_variable(rand(2, 3))
# Silence expected error messages
_logger = logging.getLogger("theano.gof.opt")
oldlevel = _logger.level
_logger.setLevel(logging.CRITICAL)
try:
try:
eval_outputs(max_and_argmax(n, 3))
assert False
except ValueError:
pass
finally:
_logger.setLevel(oldlevel)
with pytest.raises(ValueError):
eval_outputs(max_and_argmax(n, 3))
def test_basic_2_invalid_neg(self):
n = as_tensor_variable(rand(2, 3))
old_stderr = sys.stderr
sys.stderr = StringIO()
try:
try:
eval_outputs(max_and_argmax(n, -3))
assert False
except ValueError:
pass
finally:
sys.stderr = old_stderr
with pytest.raises(ValueError):
eval_outputs(max_and_argmax(n, -3))
def test_basic_2_valid_neg(self):
n = as_tensor_variable(rand(2, 3))
......@@ -3831,32 +3808,10 @@ class TestArgminArgmax:
def test2_invalid(self):
for fct, nfct in [(argmax, np.argmax), (argmin, np.argmin)]:
n = as_tensor_variable(rand(2, 3))
# Silence expected error messages
_logger = logging.getLogger("theano.gof.opt")
oldlevel = _logger.level
_logger.setLevel(logging.CRITICAL)
try:
try:
eval_outputs(fct(n, 3))
assert False
except ValueError:
pass
finally:
_logger.setLevel(oldlevel)
def test2_invalid_neg(self):
for fct, nfct in [(argmax, np.argmax), (argmin, np.argmin)]:
n = as_tensor_variable(rand(2, 3))
old_stderr = sys.stderr
sys.stderr = StringIO()
try:
try:
eval_outputs(fct(n, -3))
assert False
except ValueError:
pass
finally:
sys.stderr = old_stderr
with pytest.raises(ValueError):
eval_outputs(fct(n, 3))
with pytest.raises(ValueError):
eval_outputs(fct(n, -3))
def test2_valid_neg(self):
for fct, nfct in [(argmax, np.argmax), (argmin, np.argmin)]:
......@@ -3994,32 +3949,10 @@ class TestMinMax:
def test2_invalid(self):
for fct in [max, min]:
n = as_tensor_variable(rand(2, 3))
# Silence expected error messages
_logger = logging.getLogger("theano.gof.opt")
oldlevel = _logger.level
_logger.setLevel(logging.CRITICAL)
try:
try:
eval_outputs(fct(n, 3))
assert False
except ValueError:
pass
finally:
_logger.setLevel(oldlevel)
def test2_invalid_neg(self):
for fct in [max, min]:
n = as_tensor_variable(rand(2, 3))
old_stderr = sys.stderr
sys.stderr = StringIO()
try:
try:
eval_outputs(fct(n, -3))
assert False
except ValueError:
pass
finally:
sys.stderr = old_stderr
with pytest.raises(ValueError):
eval_outputs(fct(n, 3))
with pytest.raises(ValueError):
eval_outputs(fct(n, -3))
def test2_valid_neg(self):
for fct, nfct in [(max, np.max), (min, np.min)]:
......@@ -5659,62 +5592,20 @@ class TestDot:
self.cmp_dot(rand(4, 5, 6), rand(8, 6, 7))
def not_aligned(self, x, y):
ctv_backup = config.compute_test_value
config.compute_test_value = "off"
try:
with change_flags(compute_test_value="off"):
z = dot(x, y)
finally:
config.compute_test_value = ctv_backup
# constant folding will complain to _logger that things are not aligned
# this is normal, testers are not interested in seeing that output.
_logger = logging.getLogger("theano.gof.opt")
oldlevel = _logger.level
_logger.setLevel(logging.CRITICAL)
try:
try:
eval_outputs([z])
assert False # should have raised exception
except ValueError as e:
e0 = exc_message(e)
assert (
# Reported by numpy.
e0.split()[1:4] == ["are", "not", "aligned"]
or
# Reported by blas or Theano.
e0.split()[0:2] == ["Shape", "mismatch:"]
or
# Reported by Theano perform
(e0.split()[0:4] == ["Incompatible", "shapes", "for", "gemv"])
or e
)
finally:
_logger.setLevel(oldlevel)
with pytest.raises(ValueError):
eval_outputs([z])
def test_align_1_1(self):
def test_not_aligned(self):
self.not_aligned(rand(5), rand(6))
def test_align_1_2(self):
self.not_aligned(rand(5), rand(6, 4))
def test_align_1_3(self):
self.not_aligned(rand(5), rand(6, 4, 7))
def test_align_2_1(self):
self.not_aligned(rand(5, 4), rand(6))
def test_align_2_2(self):
self.not_aligned(rand(5, 4), rand(6, 7))
def test_align_2_3(self):
self.not_aligned(rand(5, 4), rand(6, 7, 8))
def test_align_3_1(self):
self.not_aligned(rand(5, 4, 3), rand(6))
def test_align_3_2(self):
self.not_aligned(rand(5, 4, 3), rand(6, 7))
def test_align_3_3(self):
self.not_aligned(rand(5, 4, 3), rand(6, 7, 8))
def test_grad(self):
......@@ -7459,11 +7350,14 @@ def _test_autocast_numpy_floatX():
class TestArithmeticCast:
# Test output types of basic arithmeric operations (* / + - //).
#
# We only test the behavior for `config.cast_policy` set to either 'numpy' or
# 'numpy+floatX': the 'custom' behavior is (at least partially) tested in
# `_test_autocast_custom`.
"""Test output types of basic arithmeric operations (* / + - //).
We only test the behavior for `config.cast_policy` set to either 'numpy' or
'numpy+floatX': the 'custom' behavior is (at least partially) tested in
`_test_autocast_custom`.
"""
def test_arithmetic_cast(self):
backup_config = config.cast_policy
dtypes = get_numeric_types(with_complex=True)
......@@ -7625,7 +7519,7 @@ class TestArithmeticCast:
pytest.skip("Known issue with" "numpy see #761")
# In any other situation: something wrong is
# going on!
assert False
raise AssertionError()
finally:
config.cast_policy = backup_config
if config.int_division == "int":
......@@ -7861,52 +7755,32 @@ def test_unalign():
assert not b.flags.aligned
a[:] = rand(len(a))
b[:] = rand(len(b))
out_numpy = 2 * a + 3 * b
# out_numpy = 2 * a + 3 * b
av, bv = tt.vectors("ab")
f = theano.function([av, bv], 2 * av + 3 * bv)
f.maker.fgraph.toposort()
try:
out_theano = f(a, b)
assert not a.flags.aligned
assert not b.flags.aligned
assert np.allclose(out_numpy, out_theano)
assert False
except TypeError:
pass
with pytest.raises(TypeError):
f(a, b)
a = np.empty((), dtype=dtype)["f1"]
b = np.empty((), dtype=dtype)["f1"]
assert not a.flags.aligned
assert not b.flags.aligned
out_numpy = 2 * a + 3 * b
# out_numpy = 2 * a + 3 * b
av, bv = tt.scalars("ab")
f = theano.function([av, bv], 2 * av + 3 * bv)
f.maker.fgraph.toposort()
try:
out_theano = f(a, b)
assert not a.flags.aligned
assert not b.flags.aligned
assert np.allclose(out_numpy, out_theano)
assert False
except TypeError:
pass
with pytest.raises(TypeError):
f(a, b)
def test_dimshuffle_duplicate():
x = tt.vector()
success = False
try:
with pytest.raises(ValueError, match="may not appear twice"):
tt.DimShuffle((False,), (0, 0))(x)
except ValueError as e:
assert str(e).find("may not appear twice") != -1
success = True
assert success
class TestGetScalarConstantValue:
......@@ -8033,15 +7907,11 @@ class TestGetScalarConstantValue:
assert e == 3, (c, d, e)
class TestComplexMod:
def test_complex_mod_failure():
# Make sure % fails on complex numbers.
def test_fail(self):
x = vector(dtype="complex64")
try:
x % 5
assert False
except theano.scalar.ComplexError:
pass
x = vector(dtype="complex64")
with pytest.raises(theano.scalar.ComplexError):
x % 5
class TestSize:
......
......@@ -585,11 +585,6 @@ class TestRealMatrix:
assert not _is_real_matrix(tt.DimShuffle([False], ["x", 0])(tt.dvector()))
def fail(msg):
print("FAIL", msg)
assert False
"""
This test suite ensures that Gemm is inserted where it belongs, and
that the resulting functions compute the same things as the originals.
......@@ -600,9 +595,10 @@ def XYZab():
return tt.matrix(), tt.matrix(), tt.matrix(), tt.scalar(), tt.scalar()
def just_gemm(
i, o, ishapes=[(4, 3), (3, 5), (4, 5), (), ()], max_graphlen=0, expected_nb_gemm=1
):
def just_gemm(i, o, ishapes=None, max_graphlen=0, expected_nb_gemm=1):
if ishapes is None:
ishapes = [(4, 3), (3, 5), (4, 5), (), ()]
f = inplace_func(
[In(ii, mutable=True, allow_downcast=True) for ii in i],
o,
......
......@@ -240,11 +240,8 @@ def test_inverse_singular():
singular = np.array([[1, 0, 0]] + [[0, 1, 0]] * 2, dtype=theano.config.floatX)
a = tensor.matrix()
f = function([a], matrix_inverse(a))
try:
with pytest.raises(np.linalg.LinAlgError):
f(singular)
except np.linalg.LinAlgError:
return
assert False
def test_inverse_grad():
......
......@@ -1081,11 +1081,6 @@ class TestCanonize:
assert len(topo[0].inputs) == 1
assert out_dtype == out.dtype
@pytest.mark.xfail(reason="Not implemented yet")
def test_dont_merge_if_multiple_client(self):
# test those case take from the comment in Canonizer
assert False
def test_canonicalize_nan(self):
# Regression test for bug in canonicalization of NaN values.
# This bug caused an infinite loop which was caught by the equilibrium
......@@ -1202,10 +1197,7 @@ def test_cast_in_mul_canonizer():
)
== 0
)
assert (
len([n for n in nodes if isinstance(getattr(n.op, "scalar_op"), scal.Cast)])
== 1
)
assert len([n for n in nodes if isinstance(n.op.scalar_op, scal.Cast)]) == 1
f([1], [1])
......@@ -2332,11 +2324,8 @@ def test_local_useless_inc_subtensor():
assert (out == np.asarray([[3, 4]])[::, sub]).all()
# Test that we don't remove shape error
try:
with pytest.raises(ValueError):
f([[2, 3]], [[3, 4], [4, 5]])
assert False
except (ValueError, AssertionError):
pass
# Test that we don't remove broadcastability
out = f([[2, 3], [3, 4]], [[5, 6]])
......
......@@ -25,7 +25,7 @@ class TestConfig:
configparam=ConfigParam("invalid", filter=filter),
in_c_key=False,
)
assert False
raise AssertionError()
except ValueError:
pass
......
......@@ -497,15 +497,7 @@ def test_known_grads():
full = full(*values)
assert len(true_grads) == len(full)
for a, b, var in zip(true_grads, full, inputs):
if not np.allclose(a, b):
print("Failure")
print(a)
print(b)
print(var)
print(layer)
for v in known:
print(v, ":", theano.function(inputs, known[v])(*values))
assert False
assert np.allclose(a, b)
def test_dxdx():
......
......@@ -2733,7 +2733,7 @@ class DebugMode(Mode):
check_isfinite=None,
check_preallocated_output=None,
require_matching_strides=None,
linker=_DummyLinker(),
linker=None,
):
"""
If any of these arguments (except optimizer) is not None, it overrides
......@@ -2741,6 +2741,8 @@ class DebugMode(Mode):
allow Mode.requiring() and some other fct to work with DebugMode too.
"""
if linker is None:
linker = _DummyLinker()
if not isinstance(linker, _DummyLinker):
raise Exception(
......
......@@ -3294,7 +3294,7 @@ def check_stack_trace(f_or_fgraph, ops_to_check="last", bug_print="raise"):
]
# if ops_to_check is a function
elif hasattr(ops_to_check, "__call__"):
elif callable(ops_to_check):
apply_nodes_to_check = [
node for node in fgraph.apply_nodes if ops_to_check(node)
]
......
import copy
import sys
import math
from theano.compat import DefaultOrderedDict
from theano.misc.ordered_set import OrderedSet
......@@ -184,7 +185,7 @@ class Query(object):
require=None,
exclude=None,
subquery=None,
position_cutoff=float("inf"),
position_cutoff=math.inf,
extra_optimizations=None,
):
self.include = OrderedSet(include)
......
......@@ -11,7 +11,7 @@ from theano import config
from theano.compat import PY3
def simple_extract_stack(f=None, limit=None, skips=[]):
def simple_extract_stack(f=None, limit=None, skips=None):
"""This is traceback.extract_stack from python 2.7 with this change:
- Comment the update of the cache.
......@@ -27,6 +27,9 @@ def simple_extract_stack(f=None, limit=None, skips=[]):
When we find one level that isn't skipped, we stop skipping.
"""
if skips is None:
skips = []
if f is None:
try:
raise ZeroDivisionError
......
......@@ -1125,7 +1125,7 @@ class VM_Linker(link.LocalLinker):
compute_map_re[var][0] = 1
if getattr(fgraph.profile, "dependencies", None):
dependencies = getattr(fgraph.profile, "dependencies")
dependencies = fgraph.profile.dependencies
else:
dependencies = self.compute_gc_dependencies(storage_map)
......
......@@ -34,7 +34,7 @@ try:
from skcuda import fft
skcuda_available = True
except (ImportError, Exception):
except Exception:
skcuda_available = False
......
......@@ -1669,12 +1669,14 @@ def local_gpua_assert_graph(op, context_name, inputs, outputs):
@op_lifter([ConvOp])
@register_opt2([ConvOp], "fast_compile")
def local_gpua_error_convop(op, context_name, inputs, outputs):
assert False, """
raise AssertionError(
"""
ConvOp does not work with the gpuarray backend.
Use the new convolution interface to have GPU convolution working:
theano.tensor.nnet.conv2d()
"""
)
@register_opt("fast_compile")
......
......@@ -24,7 +24,7 @@ def render_string(string, sub):
if str(F) == str(E):
raise Exception(string[0:i] + "<<<< caused exception " + str(F))
i += 1
assert False
raise AssertionError()
return finalCode
......
......@@ -1345,7 +1345,7 @@ def var_descriptor(obj, _prev_obs=None, _tag_generator=None):
name = position_independent_str(obj)
if " at 0x" in name:
print(name)
assert False
raise AssertionError()
prefix = cur_tag + "="
......
......@@ -4,9 +4,9 @@ from theano.tensor.basic import Join
def scan_checkpoints(
fn,
sequences=[],
sequences=None,
outputs_info=None,
non_sequences=[],
non_sequences=None,
name="checkpointscan_fn",
n_steps=None,
save_every_N=10,
......@@ -91,11 +91,17 @@ def scan_checkpoints(
"""
# Standardize the format of input arguments
if not isinstance(sequences, list):
if sequences is None:
sequences = []
elif not isinstance(sequences, list):
sequences = [sequences]
if not isinstance(outputs_info, list):
outputs_info = [outputs_info]
if not isinstance(non_sequences, list):
if non_sequences is None:
non_sequences = []
elif not isinstance(non_sequences, list):
non_sequences = [non_sequences]
# Check that outputs_info has no taps:
......
......@@ -246,7 +246,7 @@ def clone(
return outs
def map_variables(replacer, graphs, additional_inputs=[]):
def map_variables(replacer, graphs, additional_inputs=None):
"""Construct new graphs based on 'graphs' with some variables replaced
according to 'replacer'.
......@@ -277,6 +277,8 @@ def map_variables(replacer, graphs, additional_inputs=[]):
# v is now equal to a * b + c
"""
if additional_inputs is None:
additional_inputs = []
# wrap replacer to avoid replacing things we just put there.
graphs_seen = set()
......
......@@ -860,9 +860,7 @@ def multinomial_helper(random_state, n, pvals, size):
return out
def multinomial(
random_state, size=None, n=1, pvals=[0.5, 0.5], ndim=None, dtype="int64"
):
def multinomial(random_state, size=None, n=1, pvals=None, ndim=None, dtype="int64"):
"""
Sample from one or more multinomial distributions defined by
one-dimensional slices in pvals.
......@@ -923,6 +921,8 @@ def multinomial(
draws.
"""
if pvals is None:
pvals = [0.5, 0.5]
n = tensor.as_tensor_variable(n)
pvals = tensor.as_tensor_variable(pvals)
# until ellipsis is implemented (argh)
......@@ -1056,7 +1056,7 @@ class RandomStreamsBase(object):
"""
return self.gen(permutation, size, n, ndim=ndim, dtype=dtype)
def multinomial(self, size=None, n=1, pvals=[0.5, 0.5], ndim=None, dtype="int64"):
def multinomial(self, size=None, n=1, pvals=None, ndim=None, dtype="int64"):
"""
Sample n times from a multinomial distribution defined by
probabilities pvals, as many times as required by size. For
......@@ -1072,6 +1072,8 @@ class RandomStreamsBase(object):
Note that the output will then be of dimension ndim+1.
"""
if pvals is None:
pvals = [0.5, 0.5]
return self.gen(multinomial, size, n, pvals, ndim=ndim, dtype=dtype)
def shuffle_row_elements(self, input):
......
......@@ -113,7 +113,7 @@ class TensorType(Type):
("%s expected a ndarray object with " "dtype = %s (got %s).")
% (self, self.numpy_dtype, data.dtype)
)
assert False, "This point should never be reached."
raise AssertionError("This point should never be reached.")
else:
if allow_downcast:
# Convert to self.dtype, regardless of the type of data
......
Markdown 格式
0%
您将 0 人添加到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论