提交 64db99f7 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Apply pyupgrade to theano.scan_module

上级 59a91c4b
......@@ -46,7 +46,6 @@ import logging
from collections import OrderedDict
import numpy as np
from six import integer_types
import theano.tensor as tt
from theano import compile, config, gof
......@@ -382,7 +381,7 @@ def scan(
# To do that we check here to see the nature of n_steps
n_fixed_steps = None
if isinstance(n_steps, (float, integer_types)):
if isinstance(n_steps, (float, int)):
n_fixed_steps = int(n_steps)
else:
try:
......
......@@ -55,7 +55,6 @@ import time
from collections import OrderedDict
import numpy as np
from six import integer_types, raise_from, string_types
import theano
from theano import compile, config, gof, gradient, tensor
......@@ -944,10 +943,9 @@ class Scan(PureOp):
profile = None
if theano.config.profile or (
isinstance(self.profile, (string_types, bool, integer_types))
and self.profile
isinstance(self.profile, (str, bool, (int,))) and self.profile
):
if isinstance(self.profile, string_types):
if isinstance(self.profile, str):
profile = ScanProfileStats(name=self.profile)
else:
profile = ScanProfileStats(name=self.name)
......@@ -1591,7 +1589,7 @@ class Scan(PureOp):
"'optimizer_excluding=scanOp_pushout_output' "
"to your Theano flags."
)
raise_from(ne, e)
raise ne from e
# 5.5 Copy over the values for nit_sot outputs
begin = end
......
......@@ -56,7 +56,6 @@ from collections import OrderedDict
from sys import maxsize
import numpy as np
from six import integer_types
import theano
from theano import gof, scalar, tensor
......@@ -249,7 +248,7 @@ class PushOutNonSeqScan(gof.Optimizer):
local_fgraph_topo = theano.gof.graph.io_toposort(clean_inputs, clean_outputs)
local_fgraph_outs_set = set(clean_outputs)
local_fgraph_outs_map = dict([(v, k) for k, v in enumerate(clean_outputs)])
local_fgraph_outs_map = {v: k for k, v in enumerate(clean_outputs)}
to_remove_set = set()
to_replace_set = set()
......@@ -269,7 +268,7 @@ class PushOutNonSeqScan(gof.Optimizer):
# Construct the list of non_sequences to simplify a few things
inner_non_seqs = op.inner_non_seqs(clean_inputs)
inner_non_seqs_set = set(inner_non_seqs)
inner_non_seqs_map = dict([(v, k) for k, v in enumerate(inner_non_seqs)])
inner_non_seqs_map = {v: k for k, v in enumerate(inner_non_seqs)}
outer_non_seqs = op.outer_non_seqs(node.inputs)
......@@ -348,7 +347,7 @@ class PushOutNonSeqScan(gof.Optimizer):
existent_nodes = [nd for nd in local_fgraph_topo if nd not in to_remove_set]
existent_nodes_set = set(existent_nodes)
to_keep_set = set([])
to_keep_set = set()
for nd in existent_nodes:
to_keep_set.update(nd.inputs)
......@@ -467,7 +466,7 @@ class PushOutSeqScan(gof.Optimizer):
local_fgraph_topo = theano.gof.graph.io_toposort(clean_inputs, clean_outputs)
local_fgraph_outs_set = set(clean_outputs)
local_fgraph_outs_map = dict([(v, k) for k, v in enumerate(clean_outputs)])
local_fgraph_outs_map = {v: k for k, v in enumerate(clean_outputs)}
to_remove_set = set()
to_replace_set = set()
......@@ -487,12 +486,12 @@ class PushOutSeqScan(gof.Optimizer):
# Construct the list of non_sequences to simplify a few things
inner_non_seqs = op.inner_non_seqs(clean_inputs)
inner_non_seqs_set = set(inner_non_seqs)
inner_non_seqs_map = dict([(v, k) for k, v in enumerate(inner_non_seqs)])
inner_non_seqs_map = {v: k for k, v in enumerate(inner_non_seqs)}
outer_non_seqs = op.outer_non_seqs(node.inputs)
inner_seqs = op.inner_seqs(clean_inputs)
inner_seqs_set = set(inner_seqs)
inner_seqs_map = dict([(v, k) for k, v in enumerate(inner_seqs)])
inner_seqs_map = {v: k for k, v in enumerate(inner_seqs)}
outer_seqs = op.outer_seqs(node.inputs)
assert len(inner_non_seqs) == len(outer_non_seqs)
......@@ -605,7 +604,7 @@ class PushOutSeqScan(gof.Optimizer):
existent_nodes = [nd for nd in local_fgraph_topo if nd not in to_remove_set]
existent_nodes_set = set(existent_nodes)
to_keep_set = set([])
to_keep_set = set()
for nd in existent_nodes:
to_keep_set.update(nd.inputs)
......@@ -1297,15 +1296,13 @@ class ScanSaveMem(gof.Optimizer):
if isinstance(stop, tensor.Variable):
global_nsteps["sym"] += [stop]
# not if it is maxsize
elif type(stop) in integer_types and stop == maxsize:
elif type(stop) == int and stop == maxsize:
global_nsteps = None
# yes if it is a int k, 0 < k < maxsize
elif (
type(stop) in integer_types and global_nsteps["real"] < stop
):
elif type(stop) == int and global_nsteps["real"] < stop:
global_nsteps["real"] = stop
# yes if it is a int k, 0 < k < maxsize
elif type(stop) in integer_types and stop > 0:
elif type(stop) == int and stop > 0:
pass
# not otherwise
else:
......
......@@ -24,7 +24,6 @@ import warnings
from collections import OrderedDict
import numpy as np
from six import string_types
import theano
from theano import compat, gof, scalar, tensor
......@@ -104,7 +103,7 @@ def safe_new(x, tag="", dtype=None):
return nw_x
class until(object):
class until:
"""
Class used to encode the different things the inner function of scan can
(or needs) to return.
......@@ -596,7 +595,7 @@ def isNaN_or_Inf_or_None(x):
try:
isNaN = np.isnan(x)
isInf = np.isinf(x)
isStr = isinstance(x, string_types)
isStr = isinstance(x, str)
except Exception:
isNaN = False
isInf = False
......@@ -609,7 +608,7 @@ def isNaN_or_Inf_or_None(x):
except Exception:
isNaN = False
isInf = False
if isinstance(x, gof.Constant) and isinstance(x.data, string_types):
if isinstance(x, gof.Constant) and isinstance(x.data, str):
isStr = True
else:
isStr = False
......@@ -687,7 +686,7 @@ def infer_shape(outs, inputs, input_shapes):
return ret
class Validator(object):
class Validator:
"""
Check if variables can be expressed without using variables in invalid.
......@@ -1007,7 +1006,7 @@ def reconstruct_graph(inputs, outputs, tag=None):
return (nw_inputs, nw_outputs)
class scan_args(object):
class scan_args:
"""
Parses the inputs and outputs of scan in an easy to manipulate format.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论