提交 1b8fa84e authored 作者: carriepl's avatar carriepl

Flake8 on scan_opt.py

上级 2928d02a
...@@ -48,16 +48,7 @@ scan_eqopt2 -> They are all global optimizer. (in2out convert local to global). ...@@ -48,16 +48,7 @@ scan_eqopt2 -> They are all global optimizer. (in2out convert local to global).
in2out(scan_merge_inouts), in2out(scan_merge_inouts),
ScanSaveMem, ScanSaveMem,
in2out(remove_constants_and_unused_inputs_scan3) in2out(remove_constants_and_unused_inputs_scan3)
""" """
__docformat__ = 'restructedtext en'
__authors__ = ("Razvan Pascanu "
"Frederic Bastien "
"James Bergstra "
"Pascal Lamblin "
"Arnaud Bergeron ")
__copyright__ = "(c) 2010, Universite de Montreal"
__contact__ = "Razvan Pascanu <r.pascanu@gmail>"
import logging import logging
import copy import copy
...@@ -66,21 +57,29 @@ import numpy ...@@ -66,21 +57,29 @@ import numpy
import theano import theano
from theano import tensor from theano import tensor
from theano.tensor import opt, get_scalar_constant_value from theano.tensor import opt, get_scalar_constant_value, Alloc, AllocEmpty
from theano import gof from theano import gof
from theano.compat import OrderedDict from theano.compat import OrderedDict
from six import integer_types, iteritems from six import integer_types, iteritems
from six.moves import xrange from six.moves import xrange
from theano.gof.opt import Optimizer
from theano.gof.opt import pre_constant_merge, pre_greedy_local_optimizer
from theano.gof import toolbox, DestroyHandler, InconsistencyError
from theano.compile import optdb from theano.compile import optdb
from theano.compile.function_module import deep_copy_op from theano.compile.function_module import deep_copy_op
from theano.gof import toolbox, DestroyHandler, InconsistencyError
from theano.gof.opt import Optimizer
from theano.gof.opt import pre_constant_merge, pre_greedy_local_optimizer
from theano.scan_module import scan_op from theano.scan_module import scan_op
from theano.scan_module import scan_utils from theano.scan_module import scan_utils
from theano.scan_module.scan_utils import equal_computations, find_up, \ from theano.scan_module.scan_utils import equal_computations, find_up, scan_args
scan_args
__docformat__ = 'restructedtext en'
__authors__ = ("Razvan Pascanu "
"Frederic Bastien "
"James Bergstra "
"Pascal Lamblin "
"Arnaud Bergeron ")
__copyright__ = "(c) 2010, Universite de Montreal"
__contact__ = "Razvan Pascanu <r.pascanu@gmail>"
# Logging function for sending warning or info # Logging function for sending warning or info
...@@ -243,18 +242,17 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -243,18 +242,17 @@ class PushOutNonSeqScan(gof.Optimizer):
local_fgraph_topo = theano.gof.graph.io_toposort(clean_inputs, local_fgraph_topo = theano.gof.graph.io_toposort(clean_inputs,
clean_outputs) clean_outputs)
local_fgraph_outs_set = set(clean_outputs) local_fgraph_outs_set = set(clean_outputs)
local_fgraph_outs_map = dict([(v, k) for k, v in \ local_fgraph_outs_map = dict([(v, k) for k, v in
enumerate(clean_outputs)]) enumerate(clean_outputs)])
to_remove_set = set() to_remove_set = set()
to_replace_set = set() to_replace_set = set()
to_replace_map = OrderedDict() to_replace_map = OrderedDict()
nto_replace = 0
def add_to_replace(y): def add_to_replace(y):
to_replace_set.add(y) to_replace_set.add(y)
to_replace_map[y] = add_to_replace.n to_replace_map[y] = add_to_replace.n
add_to_replace.n +=1 add_to_replace.n += 1
add_to_replace.n = 0 add_to_replace.n = 0
replace_with_in = [] replace_with_in = []
...@@ -264,7 +262,8 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -264,7 +262,8 @@ class PushOutNonSeqScan(gof.Optimizer):
# Construct the list of non_sequences to simplify a few things # Construct the list of non_sequences to simplify a few things
inner_non_seqs = op.inner_non_seqs(clean_inputs) inner_non_seqs = op.inner_non_seqs(clean_inputs)
inner_non_seqs_set = set(inner_non_seqs) inner_non_seqs_set = set(inner_non_seqs)
inner_non_seqs_map = dict([(v,k) for k,v in enumerate(inner_non_seqs)]) inner_non_seqs_map = dict([(v, k) for k, v in
enumerate(inner_non_seqs)])
outer_non_seqs = op.outer_non_seqs(node.inputs) outer_non_seqs = op.outer_non_seqs(node.inputs)
...@@ -275,7 +274,7 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -275,7 +274,7 @@ class PushOutNonSeqScan(gof.Optimizer):
assert len(inner_seqs) == len(outer_seqs) assert len(inner_seqs) == len(outer_seqs)
for nd in local_fgraph_topo: for nd in local_fgraph_topo:
if (# we haven't already looked at this node if ( # we haven't already looked at this node
nd not in to_remove_set and nd not in to_remove_set and
all([((x in inner_non_seqs_set) or all([((x in inner_non_seqs_set) or
(x.owner in to_remove_set) or (x.owner in to_remove_set) or
...@@ -337,7 +336,7 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -337,7 +336,7 @@ class PushOutNonSeqScan(gof.Optimizer):
to_keep_set.update(nd.inputs) to_keep_set.update(nd.inputs)
for out, idx in to_replace_map.items(): for out, idx in to_replace_map.items():
if (# If types are different, conversion Op will be inserted, if ( # If types are different, conversion Op will be inserted,
# and it may trigger an infinite loop. # and it may trigger an infinite loop.
replace_with_in[idx].type == out.type and replace_with_in[idx].type == out.type and
out in to_keep_set and out in to_keep_set and
...@@ -450,13 +449,12 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -450,13 +449,12 @@ class PushOutSeqScan(gof.Optimizer):
local_fgraph_topo = theano.gof.graph.io_toposort(clean_inputs, local_fgraph_topo = theano.gof.graph.io_toposort(clean_inputs,
clean_outputs) clean_outputs)
local_fgraph_outs_set = set(clean_outputs) local_fgraph_outs_set = set(clean_outputs)
local_fgraph_outs_map = dict([(v,k) for k,v in \ local_fgraph_outs_map = dict([(v, k) for k, v in
enumerate(clean_outputs)]) enumerate(clean_outputs)])
to_remove_set = set() to_remove_set = set()
to_replace_set = set() to_replace_set = set()
to_replace_map = OrderedDict() to_replace_map = OrderedDict()
nto_replace = 0
def add_to_replace(y): def add_to_replace(y):
to_replace_set.add(y) to_replace_set.add(y)
...@@ -471,12 +469,14 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -471,12 +469,14 @@ class PushOutSeqScan(gof.Optimizer):
# Construct the list of non_sequences to simplify a few things # Construct the list of non_sequences to simplify a few things
inner_non_seqs = op.inner_non_seqs(clean_inputs) inner_non_seqs = op.inner_non_seqs(clean_inputs)
inner_non_seqs_set = set(inner_non_seqs) inner_non_seqs_set = set(inner_non_seqs)
inner_non_seqs_map = dict([(v,k) for k,v in enumerate(inner_non_seqs)]) inner_non_seqs_map = dict([(v, k) for k, v in
enumerate(inner_non_seqs)])
outer_non_seqs = op.outer_non_seqs(node.inputs) outer_non_seqs = op.outer_non_seqs(node.inputs)
inner_seqs = op.inner_seqs(clean_inputs) inner_seqs = op.inner_seqs(clean_inputs)
inner_seqs_set = set(inner_seqs) inner_seqs_set = set(inner_seqs)
inner_seqs_map = dict([(v,k) for k,v in enumerate(inner_seqs)]) inner_seqs_map = dict([(v, k) for k, v in
enumerate(inner_seqs)])
outer_seqs = op.outer_seqs(node.inputs) outer_seqs = op.outer_seqs(node.inputs)
assert len(inner_non_seqs) == len(outer_non_seqs) assert len(inner_non_seqs) == len(outer_non_seqs)
...@@ -582,11 +582,10 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -582,11 +582,10 @@ class PushOutSeqScan(gof.Optimizer):
to_keep_set.update(nd.inputs) to_keep_set.update(nd.inputs)
for out, idx in to_replace_map.items(): for out, idx in to_replace_map.items():
if (out in to_keep_set if (out in to_keep_set and out.owner not in existent_nodes_set and
and out.owner not in existent_nodes_set
# If types are different, conversion Op will be inserted, # If types are different, conversion Op will be inserted,
# and it may trigger an infinite loop. # and it may trigger an infinite loop.
and replace_with_in[idx].type == out.type): replace_with_in[idx].type == out.type):
clean_to_replace.append(out) clean_to_replace.append(out)
clean_replace_with_in.append(replace_with_in[idx]) clean_replace_with_in.append(replace_with_in[idx])
...@@ -682,7 +681,7 @@ class PushOutScanOutput(gof.Optimizer): ...@@ -682,7 +681,7 @@ class PushOutScanOutput(gof.Optimizer):
not x.op.as_while)] not x.op.as_while)]
for node in nodelist: for node in nodelist:
# Process the node as long as something gets optimized # Process the node as long as something gets optimized
while node != None: while node is not None:
node = self.process_node(fgraph, node) node = self.process_node(fgraph, node)
def process_node(self, fgraph, node): def process_node(self, fgraph, node):
...@@ -778,9 +777,8 @@ class PushOutScanOutput(gof.Optimizer): ...@@ -778,9 +777,8 @@ class PushOutScanOutput(gof.Optimizer):
outer_dot_output = theano.tensor.dot(*outer_dot_inputs) outer_dot_output = theano.tensor.dot(*outer_dot_inputs)
# Modify the outer graph to add the outer Dot # Modify the outer graph to add the outer Dot
fgraph.replace_all([ fgraph.replace_all(
(new_scan_args.outer_out_nit_sot[ [(new_scan_args.outer_out_nit_sot[dot_out_nitsot_idx],
dot_out_nitsot_idx],
outer_dot_output)], outer_dot_output)],
reason="scanOp_pushout_output") reason="scanOp_pushout_output")
...@@ -807,8 +805,9 @@ class PushOutScanOutput(gof.Optimizer): ...@@ -807,8 +805,9 @@ class PushOutScanOutput(gof.Optimizer):
sitsot_in_idx = nd.inputs.index(args.inner_in_sit_sot[ sitsot_in_idx = nd.inputs.index(args.inner_in_sit_sot[
sitsot_idx]) sitsot_idx])
dot_in_idx = 1 - sitsot_in_idx # 0 if sitsot_in_idx==1, # 0 if sitsot_in_idx==1, 1 if sitsot_in_idx==0
# 1 if sitsot_in_idx==0 dot_in_idx = 1 - sitsot_in_idx
dot_input = nd.inputs[dot_in_idx] dot_input = nd.inputs[dot_in_idx]
if (dot_input.owner is not None and if (dot_input.owner is not None and
...@@ -816,10 +815,8 @@ class PushOutScanOutput(gof.Optimizer): ...@@ -816,10 +815,8 @@ class PushOutScanOutput(gof.Optimizer):
len(dot_input.clients) == 1 and len(dot_input.clients) == 1 and
dot_input.owner.inputs[0].ndim == 2 and dot_input.owner.inputs[0].ndim == 2 and
dot_input.owner.inputs[1].ndim == 2 and dot_input.owner.inputs[1].ndim == 2 and
self.get_outer_ndim(dot_input.owner.inputs[0], args) \ self.get_outer_ndim(dot_input.owner.inputs[0], args) == 3 and
== 3 and self.get_outer_ndim(dot_input.owner.inputs[1], args) == 3):
self.get_outer_ndim(dot_input.owner.inputs[1], args) \
== 3):
# The optimization can be applied in this case. # The optimization can be applied in this case.
...@@ -829,8 +826,7 @@ class PushOutScanOutput(gof.Optimizer): ...@@ -829,8 +826,7 @@ class PushOutScanOutput(gof.Optimizer):
(outer_dot_inputs, (outer_dot_inputs,
new_scan_node, new_scan_node,
new_scan_args) = \ new_scan_args) = \
self.push_out_inner_vars(fgraph, self.push_out_inner_vars(fgraph, inner_dot_inputs,
inner_dot_inputs,
node, args) node, args)
# Collapse some of the dimensions of the tensors # Collapse some of the dimensions of the tensors
...@@ -838,8 +834,7 @@ class PushOutScanOutput(gof.Optimizer): ...@@ -838,8 +834,7 @@ class PushOutScanOutput(gof.Optimizer):
# dot is usually faster on two large matrices than # dot is usually faster on two large matrices than
# a bunch of small ones # a bunch of small ones
outer_dot_inputs[0] = theano.tensor.flatten( outer_dot_inputs[0] = theano.tensor.flatten(
outer_dot_inputs[0].dimshuffle(1, 0, 2), outer_dot_inputs[0].dimshuffle(1, 0, 2), outdim=2)
outdim=2)
shape_input1 = theano.tensor.shape(outer_dot_inputs[1]) shape_input1 = theano.tensor.shape(outer_dot_inputs[1])
outer_dot_inputs[1] =\ outer_dot_inputs[1] =\
...@@ -850,15 +845,13 @@ class PushOutScanOutput(gof.Optimizer): ...@@ -850,15 +845,13 @@ class PushOutScanOutput(gof.Optimizer):
# Perform the dot on the newly obtained matrices and # Perform the dot on the newly obtained matrices and
# add the initial value # add the initial value
outer_dot_output = theano.tensor.dot(*outer_dot_inputs) outer_dot_output = theano.tensor.dot(*outer_dot_inputs)
init_value = \ init_value = new_scan_args.outer_in_sit_sot[sitsot_idx][0]
new_scan_args.outer_in_sit_sot[sitsot_idx][0]
replacement = outer_dot_output + init_value replacement = outer_dot_output + init_value
# Alter the outer graph to use the output of the # Alter the outer graph to use the output of the
# external Dot instead of the output of scan # external Dot instead of the output of scan
# Modify the outer graph to add the outer Dot # Modify the outer graph to add the outer Dot
outer_sitsot = \ outer_sitsot = new_scan_args.outer_out_sit_sot[sitsot_idx]
new_scan_args.outer_out_sit_sot[sitsot_idx]
subtensor_node = outer_sitsot.clients[0][0] subtensor_node = outer_sitsot.clients[0][0]
outer_sitsot_last_step = subtensor_node.outputs[0] outer_sitsot_last_step = subtensor_node.outputs[0]
...@@ -883,8 +876,8 @@ class PushOutScanOutput(gof.Optimizer): ...@@ -883,8 +876,8 @@ class PushOutScanOutput(gof.Optimizer):
if len(outer_var.clients) == 1: if len(outer_var.clients) == 1:
client = outer_var.clients[0][0] client = outer_var.clients[0][0]
if (client != 'output' and if (client != 'output' and isinstance(client.op,
isinstance(client.op, theano.tensor.Subtensor)): theano.tensor.Subtensor)):
lst = theano.tensor.subtensor.get_idx_list( lst = theano.tensor.subtensor.get_idx_list(
client.inputs, client.op.idx_list) client.inputs, client.op.idx_list)
if (len(lst) == 1 and if (len(lst) == 1 and
...@@ -991,7 +984,7 @@ class PushOutScanOutput(gof.Optimizer): ...@@ -991,7 +984,7 @@ class PushOutScanOutput(gof.Optimizer):
new_node_old_outputs = ( new_node_old_outputs = (
new_scan_node.outputs[:new_node_new_outputs_idx] + new_scan_node.outputs[:new_node_new_outputs_idx] +
new_scan_node.outputs[new_node_new_outputs_idx+nb_new_outs:]) new_scan_node.outputs[new_node_new_outputs_idx + nb_new_outs:])
fgraph.replace_all_validate_remove( fgraph.replace_all_validate_remove(
list(zip(old_scan_node.outputs, new_node_old_outputs)), list(zip(old_scan_node.outputs, new_node_old_outputs)),
...@@ -1242,7 +1235,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -1242,7 +1235,7 @@ class ScanSaveMem(gof.Optimizer):
for cl, _ in out.clients: for cl, _ in out.clients:
# 2.1 outputs of the function # 2.1 outputs of the function
#=> output needs all its intermediate values # => output needs all its intermediate values
if type(cl) == str: if type(cl) == str:
# if the node is actually an output, then # if the node is actually an output, then
# we need to store the entire thing # we need to store the entire thing
...@@ -1250,20 +1243,20 @@ class ScanSaveMem(gof.Optimizer): ...@@ -1250,20 +1243,20 @@ class ScanSaveMem(gof.Optimizer):
slices[i] = None slices[i] = None
break break
# 2.2 non-subtensor nodes # 2.2 non-subtensor nodes
#=> output needs all its intermediate values # => output needs all its intermediate values
elif not isinstance(cl.op, tensor.Subtensor): elif not isinstance(cl.op, tensor.Subtensor):
global_nsteps = None global_nsteps = None
slices[i] = None slices[i] = None
break break
# 2.3 subtensor nodes # 2.3 subtensor nodes
#=> output might need to store just a subset of its values # => output might need to store just a subset of its values
else: else:
# 2.3.1 extract idx list of subtensor # 2.3.1 extract idx list of subtensor
this_slice = tensor.get_idx_list(cl.inputs, this_slice = tensor.get_idx_list(cl.inputs,
cl.op.idx_list) cl.op.idx_list)
if this_slice is None: if this_slice is None:
# if unable to extract idx_list # if unable to extract idx_list
#=> outputs needs all its intermediate values # => outputs needs all its intermediate values
global_nsteps = None global_nsteps = None
slices[i] = None slices[i] = None
break break
...@@ -1406,8 +1399,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -1406,8 +1399,7 @@ class ScanSaveMem(gof.Optimizer):
# for mitsots and sitsots (because mitmots are not # for mitsots and sitsots (because mitmots are not
# currently supported by the mechanism) and only if # currently supported by the mechanism) and only if
# the pre-allocation mechanism is activated. # the pre-allocation mechanism is activated.
prealloc_outs = \ prealloc_outs = theano.config.scan.allow_output_prealloc
theano.config.scan.allow_output_prealloc
first_mitsot_idx = node.op.n_mit_mot first_mitsot_idx = node.op.n_mit_mot
last_sitsot_idx = (node.op.n_mit_mot + last_sitsot_idx = (node.op.n_mit_mot +
...@@ -1433,7 +1425,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -1433,7 +1425,7 @@ class ScanSaveMem(gof.Optimizer):
# currently. # currently.
# pval = pre_greedy_local_optimizer(list_opt_slice, # pval = pre_greedy_local_optimizer(list_opt_slice,
# pval) # pval)
#pval = pre_constant_merge([pval])[0] # pval = pre_constant_merge([pval])[0]
# if (isinstance(pval, theano.tensor.TensorConstant) # if (isinstance(pval, theano.tensor.TensorConstant)
# and # and
# pval.dtype.startswith('int')): # pval.dtype.startswith('int')):
...@@ -1554,7 +1546,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -1554,7 +1546,7 @@ class ScanSaveMem(gof.Optimizer):
nw_steps) nw_steps)
nw_inputs[in_idx] = nw_input nw_inputs[in_idx] = nw_input
else: else:
nw_input = nw_inputs[in_idx][:(initl+nw_steps)] nw_input = nw_inputs[in_idx][:(initl + nw_steps)]
elif idx < op.n_mit_sot + op.n_sit_sot + op.n_nit_sot: elif idx < op.n_mit_sot + op.n_sit_sot + op.n_nit_sot:
in_idx = offset + idx + op.n_shared_outs in_idx = offset + idx + op.n_shared_outs
...@@ -1640,8 +1632,8 @@ class ScanSaveMem(gof.Optimizer): ...@@ -1640,8 +1632,8 @@ class ScanSaveMem(gof.Optimizer):
stop = None stop = None
nw_slice = ((slice(sanitize(start), nw_slice = ((slice(sanitize(start),
sanitize(stop), sanitize(stop),
sanitize(cnf_slice[0].step)),) sanitize(cnf_slice[0].step)),) +
+ tuple(old_slices[1:])) tuple(old_slices[1:]))
else: else:
position = (cnf_slice[0] - nw_steps - position = (cnf_slice[0] - nw_steps -
...@@ -1662,8 +1654,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -1662,8 +1654,7 @@ class ScanSaveMem(gof.Optimizer):
# 3.9. Get replace pairs for all other nodes # 3.9. Get replace pairs for all other nodes
if flag_store or global_nsteps is not None: if flag_store or global_nsteps is not None:
for idx, o in enumerate(node.outputs): for idx, o in enumerate(node.outputs):
if not (idx in replaced_outs) and \ if not (idx in replaced_outs) and idx not in not_required:
not idx in not_required:
nw_pos = compress_map[idx] nw_pos = compress_map[idx]
old_new += [(o, new_outs[nw_pos])] old_new += [(o, new_outs[nw_pos])]
# Check if the new outputs depend on the old scan node # Check if the new outputs depend on the old scan node
...@@ -2072,8 +2063,8 @@ def scan_merge_inouts(node): ...@@ -2072,8 +2063,8 @@ def scan_merge_inouts(node):
# because they could have different sizes, and the corresponding # because they could have different sizes, and the corresponding
# outer outputs cannot be merged in that case. # outer outputs cannot be merged in that case.
for s_outer_i, s_inner_o, s_outer_o in seen: for s_outer_i, s_inner_o, s_outer_o in seen:
if (equal_computations([inner_o], [s_inner_o], left, right) if (equal_computations([inner_o], [s_inner_o], left, right) and
and outer_i == s_outer_i): outer_i == s_outer_i):
return s_outer_o return s_outer_o
seen.append((outer_i, inner_o, outer_o)) seen.append((outer_i, inner_o, outer_o))
return outer_o return outer_o
...@@ -2116,9 +2107,10 @@ def scan_merge_inouts(node): ...@@ -2116,9 +2107,10 @@ def scan_merge_inouts(node):
na.outer_out_mit_mot, na.outer_out_mit_mot,
na.mit_mot_out_slices): na.mit_mot_out_slices):
for s_outer_imm, s_inner_omm, s_outer_omm, sosl in seen: for s_outer_imm, s_inner_omm, s_outer_omm, sosl in seen:
if (osl == sosl if (osl == sosl and
and equal_computations(inner_omm, s_inner_omm, left, right) equal_computations(inner_omm, s_inner_omm, left, right) and
and outer_imm == s_outer_imm): outer_imm == s_outer_imm):
new_outer_out_mit_mot.append(s_outer_omm) new_outer_out_mit_mot.append(s_outer_omm)
break break
else: else:
...@@ -2168,17 +2160,15 @@ class PushOutDot1(gof.Optimizer): ...@@ -2168,17 +2160,15 @@ class PushOutDot1(gof.Optimizer):
inp in out.owner.inputs and inp in out.owner.inputs and
len(outer_out.clients) == 1 and len(outer_out.clients) == 1 and
not isinstance(outer_out.clients[0][0], str) and not isinstance(outer_out.clients[0][0], str) and
isinstance(outer_out.clients[0][0].op, theano.tensor.Subtensor) isinstance(outer_out.clients[0][0].op, theano.tensor.Subtensor) and
and outer_out.clients[0][0].op.idx_list == (-1,)): outer_out.clients[0][0].op.idx_list == (-1,)):
x = out.owner.inputs[0] x = out.owner.inputs[0]
if x == inp: if x == inp:
x = out.owner.inputs[1] x = out.owner.inputs[1]
# We need to check if x is the result of an outer product # We need to check if x is the result of an outer product
if (x.owner and if (x.owner and isinstance(x.owner.op, theano.tensor.Dot) and
isinstance(x.owner.op, theano.tensor.Dot) and x.owner.inputs[0].ndim == 2 and x.owner.inputs[1].ndim == 2):
x.owner.inputs[0].ndim == 2 and
x.owner.inputs[1].ndim == 2):
# We need to check if any of the inputs are a sequence # We need to check if any of the inputs are a sequence
inp1 = x.owner.inputs[0] inp1 = x.owner.inputs[0]
...@@ -2219,18 +2209,17 @@ class PushOutDot1(gof.Optimizer): ...@@ -2219,18 +2209,17 @@ class PushOutDot1(gof.Optimizer):
new_info = op.info.copy() new_info = op.info.copy()
st = len(op.mitmot_taps()) + len(op.mitsot_taps()) st = len(op.mitmot_taps()) + len(op.mitsot_taps())
new_info['tap_array'] = (\ new_info['tap_array'] = (
new_info['tap_array'][:st + idx] + new_info['tap_array'][:st + idx] +
new_info['tap_array'][st + new_info['tap_array'][st + idx + 1:])
idx + 1:])
new_info['n_sit_sot'] -= 1 new_info['n_sit_sot'] -= 1
new_info['n_nit_sot'] += 1 new_info['n_nit_sot'] += 1
inner_sitsot = inner_sitsot[:idx] + \ inner_sitsot = (inner_sitsot[:idx] +
inner_sitsot[idx + 1:] inner_sitsot[idx + 1:])
outer_sitsot = outer_sitsot[:idx] + \ outer_sitsot = (outer_sitsot[:idx] +
outer_sitsot[idx + 1:] outer_sitsot[idx + 1:])
inner_sitsot_outs = inner_sitsot_outs[:idx] +\ inner_sitsot_outs = (inner_sitsot_outs[:idx] +
inner_sitsot_outs[idx + 1:] inner_sitsot_outs[idx + 1:])
# add n_steps as the length # add n_steps as the length
inner_nitsot_outs.append(new_scan_out) inner_nitsot_outs.append(new_scan_out)
...@@ -2246,8 +2235,8 @@ class PushOutDot1(gof.Optimizer): ...@@ -2246,8 +2235,8 @@ class PushOutDot1(gof.Optimizer):
inner_nitsot_outs + inner_nitsot_outs +
inner_shared_outs) inner_shared_outs)
new_inner_inps, new_inner_outs =\ new_inner_inps, new_inner_outs =\
scan_utils.reconstruct_graph( scan_utils.reconstruct_graph(_new_inner_inps,
_new_inner_inps, _new_inner_outs) _new_inner_outs)
new_op = scan_op.Scan(new_inner_inps, new_inner_outs, new_op = scan_op.Scan(new_inner_inps, new_inner_outs,
new_info) new_info)
_scan_inputs = ([node.inputs[0]] + _scan_inputs = ([node.inputs[0]] +
...@@ -2267,11 +2256,7 @@ class PushOutDot1(gof.Optimizer): ...@@ -2267,11 +2256,7 @@ class PushOutDot1(gof.Optimizer):
# We need now to pair correctly the new outputs # We need now to pair correctly the new outputs
# with the old ones # with the old ones
outer_mitmot_outs = new_op.outer_mitmot_outs(new_outs)
outer_mitsot_outs = new_op.outer_mitsot_outs(new_outs)
outer_sitsot_outs = new_op.outer_sitsot_outs(new_outs)
outer_nitsot_outs = new_op.outer_nitsot_outs(new_outs) outer_nitsot_outs = new_op.outer_nitsot_outs(new_outs)
outer_shared_outs = new_op.outer_shared_outs(new_outs)
_val = outer_nitsot_outs[-1] _val = outer_nitsot_outs[-1]
outer_nitsot_outs = outer_nitsot_outs[:-1] outer_nitsot_outs = outer_nitsot_outs[:-1]
...@@ -2305,7 +2290,7 @@ class PushOutDot1(gof.Optimizer): ...@@ -2305,7 +2290,7 @@ class PushOutDot1(gof.Optimizer):
old_new = list(zip(node.outputs[:pos], new_outs[:pos])) old_new = list(zip(node.outputs[:pos], new_outs[:pos]))
old = node.outputs[pos].clients[0][0].outputs[0] old = node.outputs[pos].clients[0][0].outputs[0]
old_new.append((old, new_out)) old_new.append((old, new_out))
old_new += list(zip(node.outputs[pos+1:], old_new += list(zip(node.outputs[pos + 1:],
new_outs[pos:])) new_outs[pos:]))
fgraph.replace_all_validate_remove( fgraph.replace_all_validate_remove(
old_new, remove=[node], reason='scan_pushout_dot1') old_new, remove=[node], reason='scan_pushout_dot1')
......
...@@ -164,7 +164,6 @@ whitelist_flake8 = [ ...@@ -164,7 +164,6 @@ whitelist_flake8 = [
"scan_module/scan_op.py", "scan_module/scan_op.py",
"scan_module/scan_perform_ext.py", "scan_module/scan_perform_ext.py",
"scan_module/__init__.py", "scan_module/__init__.py",
"scan_module/scan_opt.py",
"scan_module/tests/test_scan.py", "scan_module/tests/test_scan.py",
"scan_module/tests/test_scan_opt.py", "scan_module/tests/test_scan_opt.py",
"misc/tests/test_may_share_memory.py", "misc/tests/test_may_share_memory.py",
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论