提交 044f52ec authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Apply pyupgrade to top-level modules in tests sub-package

上级 021ed662
......@@ -18,7 +18,7 @@ class MismatchError(Exception):
"""
class Record(object):
class Record:
"""
Records a sequence of strings (from a string buffer). These can then be
compared to another sequence of strings, and if the two sequences don't
......@@ -67,7 +67,7 @@ class Record(object):
assert file_object is not None or file_path is not None
if replay and file_object is None:
self.f = open(file_path, "r")
self.f = open(file_path)
elif (not replay) and file_object is None:
self.f = open(file_path, "w")
else:
......@@ -157,7 +157,7 @@ class RecordMode(Mode):
"""
self.record = record
self.known_fgraphs = set([])
self.known_fgraphs = set()
def __init__(self, record=None, **kwargs):
"""
......@@ -262,4 +262,4 @@ class RecordMode(Mode):
linker = theano.gof.vm.VM_Linker(use_cloop=bool(theano.config.cxx))
wrap_linker = theano.gof.WrapLinkerMany([linker], [callback])
super(RecordMode, self).__init__(wrap_linker, optimizer="fast_run")
super().__init__(wrap_linker, optimizer="fast_run")
......@@ -47,7 +47,6 @@ class TestGradSourcesInputs:
def grad(self, inp, grads):
(x,) = inp
(gz,) = grads
pass
a = retNone().make_node()
with pytest.raises(TypeError):
......
......@@ -114,7 +114,7 @@ class RopLopChecker:
v1 = rop_f(vx, vv)
v2 = scan_f(vx, vv)
assert np.allclose(v1, v2), "ROP mismatch: %s %s" % (v1, v2)
assert np.allclose(v1, v2), "ROP mismatch: {} {}".format(v1, v2)
self.check_nondiff_rop(theano.clone(y, replace={self.mx: break_op(self.mx)}))
......@@ -127,7 +127,7 @@ class RopLopChecker:
v1 = lop_f(vx, vv)
v2 = scan_f(vx, vv)
assert np.allclose(v1, v2), "LOP mismatch: %s %s" % (v1, v2)
assert np.allclose(v1, v2), "LOP mismatch: {} {}".format(v1, v2)
def check_rop_lop(self, y, out_shape):
"""
......@@ -151,7 +151,7 @@ class RopLopChecker:
v1 = rop_f(vx, vv)
v2 = scan_f(vx, vv)
assert np.allclose(v1, v2), "ROP mismatch: %s %s" % (v1, v2)
assert np.allclose(v1, v2), "ROP mismatch: {} {}".format(v1, v2)
try:
tensor.Rop(
......@@ -179,7 +179,7 @@ class RopLopChecker:
v1 = lop_f(vx, vv)
v2 = scan_f(vx, vv)
assert np.allclose(v1, v2), "LOP mismatch: %s %s" % (v1, v2)
assert np.allclose(v1, v2), "LOP mismatch: {} {}".format(v1, v2)
class TestRopLop(RopLopChecker):
......@@ -297,7 +297,7 @@ class TestRopLop(RopLopChecker):
scan_f = function([], sy, on_unused_input="ignore", mode=mode)
v1 = rop_f()
v2 = scan_f()
assert np.allclose(v1, v2), "Rop mismatch: %s %s" % (v1, v2)
assert np.allclose(v1, v2), "Rop mismatch: {} {}".format(v1, v2)
def test_conv(self):
for conv_op in [conv.conv2d, conv2d]:
......@@ -353,7 +353,7 @@ class TestRopLop(RopLopChecker):
ev_filter_data = np.random.random(filter_shape).astype(dtype)
v1 = rop_f(image_data, filter_data, ev_image_data, ev_filter_data)
v2 = scan_f(image_data, filter_data, ev_image_data, ev_filter_data)
assert np.allclose(v1, v2), "Rop mismatch: %s %s" % (v1, v2)
assert np.allclose(v1, v2), "Rop mismatch: {} {}".format(v1, v2)
def test_join(self):
tv = np.asarray(self.rng.uniform(size=(10,)), theano.config.floatX)
......
......@@ -5,7 +5,6 @@ from functools import wraps
import numpy as np
import pytest
from six import integer_types
import theano
import theano.tensor as tt
......@@ -110,7 +109,7 @@ class MockRandomState:
return out + maxval - 1
class OptimizationTestMixin(object):
class OptimizationTestMixin:
def assertFunctionContains(self, f, op, min=1, max=sys.maxsize):
toposort = f.maker.fgraph.toposort()
matches = [node for node in toposort if node.op == op]
......@@ -148,7 +147,7 @@ class OptimizationTestMixin(object):
return self.assertFunctionContainsClass(f, op, min=N, max=N)
class OpContractTestMixin(object):
class OpContractTestMixin:
# self.ops should be a list of instantiations of an Op class to test.
# self.other_op should be an op which is different from every op
other_op = tt.add
......@@ -236,7 +235,7 @@ class InferShapeTester:
mode = mode.excluding(*excluding)
if warn:
for var, inp in zip(inputs, numeric_inputs):
if isinstance(inp, (integer_types, float, list, tuple)):
if isinstance(inp, (int, float, list, tuple)):
inp = var.type.filter(inp)
if not hasattr(inp, "shape"):
continue
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论