提交 1b8fa84e authored 作者: carriepl's avatar carriepl

Flake8 on scan_opt.py

上级 2928d02a
...@@ -48,16 +48,7 @@ scan_eqopt2 -> They are all global optimizer. (in2out convert local to global). ...@@ -48,16 +48,7 @@ scan_eqopt2 -> They are all global optimizer. (in2out convert local to global).
in2out(scan_merge_inouts), in2out(scan_merge_inouts),
ScanSaveMem, ScanSaveMem,
in2out(remove_constants_and_unused_inputs_scan3) in2out(remove_constants_and_unused_inputs_scan3)
""" """
__docformat__ = 'restructedtext en'
__authors__ = ("Razvan Pascanu "
"Frederic Bastien "
"James Bergstra "
"Pascal Lamblin "
"Arnaud Bergeron ")
__copyright__ = "(c) 2010, Universite de Montreal"
__contact__ = "Razvan Pascanu <r.pascanu@gmail>"
import logging import logging
import copy import copy
...@@ -66,21 +57,29 @@ import numpy ...@@ -66,21 +57,29 @@ import numpy
import theano import theano
from theano import tensor from theano import tensor
from theano.tensor import opt, get_scalar_constant_value from theano.tensor import opt, get_scalar_constant_value, Alloc, AllocEmpty
from theano import gof from theano import gof
from theano.compat import OrderedDict from theano.compat import OrderedDict
from six import integer_types, iteritems from six import integer_types, iteritems
from six.moves import xrange from six.moves import xrange
from theano.gof.opt import Optimizer
from theano.gof.opt import pre_constant_merge, pre_greedy_local_optimizer
from theano.gof import toolbox, DestroyHandler, InconsistencyError
from theano.compile import optdb from theano.compile import optdb
from theano.compile.function_module import deep_copy_op from theano.compile.function_module import deep_copy_op
from theano.gof import toolbox, DestroyHandler, InconsistencyError
from theano.gof.opt import Optimizer
from theano.gof.opt import pre_constant_merge, pre_greedy_local_optimizer
from theano.scan_module import scan_op from theano.scan_module import scan_op
from theano.scan_module import scan_utils from theano.scan_module import scan_utils
from theano.scan_module.scan_utils import equal_computations, find_up, \ from theano.scan_module.scan_utils import equal_computations, find_up, scan_args
scan_args
__docformat__ = 'restructedtext en'
__authors__ = ("Razvan Pascanu "
"Frederic Bastien "
"James Bergstra "
"Pascal Lamblin "
"Arnaud Bergeron ")
__copyright__ = "(c) 2010, Universite de Montreal"
__contact__ = "Razvan Pascanu <r.pascanu@gmail>"
# Logging function for sending warning or info # Logging function for sending warning or info
...@@ -151,7 +150,7 @@ def remove_constants_and_unused_inputs_scan(node): ...@@ -151,7 +150,7 @@ def remove_constants_and_unused_inputs_scan(node):
for idx in xrange(op.n_seqs): for idx in xrange(op.n_seqs):
node_inp = node.inputs[idx + 1] node_inp = node.inputs[idx + 1]
if (isinstance(node_inp, tensor.TensorConstant) and if (isinstance(node_inp, tensor.TensorConstant) and
node_inp.tag.unique_value is not None): node_inp.tag.unique_value is not None):
try: try:
# This works if input is a constant that has all entries # This works if input is a constant that has all entries
# equal # equal
...@@ -243,18 +242,17 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -243,18 +242,17 @@ class PushOutNonSeqScan(gof.Optimizer):
local_fgraph_topo = theano.gof.graph.io_toposort(clean_inputs, local_fgraph_topo = theano.gof.graph.io_toposort(clean_inputs,
clean_outputs) clean_outputs)
local_fgraph_outs_set = set(clean_outputs) local_fgraph_outs_set = set(clean_outputs)
local_fgraph_outs_map = dict([(v, k) for k, v in \ local_fgraph_outs_map = dict([(v, k) for k, v in
enumerate(clean_outputs)]) enumerate(clean_outputs)])
to_remove_set = set() to_remove_set = set()
to_replace_set = set() to_replace_set = set()
to_replace_map = OrderedDict() to_replace_map = OrderedDict()
nto_replace = 0
def add_to_replace(y): def add_to_replace(y):
to_replace_set.add(y) to_replace_set.add(y)
to_replace_map[y] = add_to_replace.n to_replace_map[y] = add_to_replace.n
add_to_replace.n +=1 add_to_replace.n += 1
add_to_replace.n = 0 add_to_replace.n = 0
replace_with_in = [] replace_with_in = []
...@@ -264,7 +262,8 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -264,7 +262,8 @@ class PushOutNonSeqScan(gof.Optimizer):
# Construct the list of non_sequences to simplify a few things # Construct the list of non_sequences to simplify a few things
inner_non_seqs = op.inner_non_seqs(clean_inputs) inner_non_seqs = op.inner_non_seqs(clean_inputs)
inner_non_seqs_set = set(inner_non_seqs) inner_non_seqs_set = set(inner_non_seqs)
inner_non_seqs_map = dict([(v,k) for k,v in enumerate(inner_non_seqs)]) inner_non_seqs_map = dict([(v, k) for k, v in
enumerate(inner_non_seqs)])
outer_non_seqs = op.outer_non_seqs(node.inputs) outer_non_seqs = op.outer_non_seqs(node.inputs)
...@@ -275,17 +274,17 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -275,17 +274,17 @@ class PushOutNonSeqScan(gof.Optimizer):
assert len(inner_seqs) == len(outer_seqs) assert len(inner_seqs) == len(outer_seqs)
for nd in local_fgraph_topo: for nd in local_fgraph_topo:
if (# we haven't already looked at this node if ( # we haven't already looked at this node
nd not in to_remove_set and nd not in to_remove_set and
all([((x in inner_non_seqs_set) or all([((x in inner_non_seqs_set) or
(x.owner in to_remove_set) or (x.owner in to_remove_set) or
isinstance(x, tensor.Constant)) isinstance(x, tensor.Constant))
for x in nd.inputs]) and for x in nd.inputs]) and
# we can do this because the assumption is that a # we can do this because the assumption is that a
# viewOp or deepCopyOp will be just at the end of the # viewOp or deepCopyOp will be just at the end of the
# function and not somewhere in the middle .. # function and not somewhere in the middle ..
not isinstance(nd.op, theano.compile.ViewOp) and not isinstance(nd.op, theano.compile.ViewOp) and
not isinstance(nd.op, theano.compile.DeepCopyOp)): not isinstance(nd.op, theano.compile.DeepCopyOp)):
# We have a candidate node to removable # We have a candidate node to removable
# Step 1. Reconstruct it on outside # Step 1. Reconstruct it on outside
...@@ -337,11 +336,11 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -337,11 +336,11 @@ class PushOutNonSeqScan(gof.Optimizer):
to_keep_set.update(nd.inputs) to_keep_set.update(nd.inputs)
for out, idx in to_replace_map.items(): for out, idx in to_replace_map.items():
if (# If types are different, conversion Op will be inserted, if ( # If types are different, conversion Op will be inserted,
# and it may trigger an infinite loop. # and it may trigger an infinite loop.
replace_with_in[idx].type == out.type and replace_with_in[idx].type == out.type and
out in to_keep_set and out in to_keep_set and
out.owner not in existent_nodes_set): out.owner not in existent_nodes_set):
clean_to_replace.append(out) clean_to_replace.append(out)
clean_replace_with_in.append(replace_with_in[idx]) clean_replace_with_in.append(replace_with_in[idx])
clean_replace_with_out.append(replace_with_out[idx]) clean_replace_with_out.append(replace_with_out[idx])
...@@ -450,13 +449,12 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -450,13 +449,12 @@ class PushOutSeqScan(gof.Optimizer):
local_fgraph_topo = theano.gof.graph.io_toposort(clean_inputs, local_fgraph_topo = theano.gof.graph.io_toposort(clean_inputs,
clean_outputs) clean_outputs)
local_fgraph_outs_set = set(clean_outputs) local_fgraph_outs_set = set(clean_outputs)
local_fgraph_outs_map = dict([(v,k) for k,v in \ local_fgraph_outs_map = dict([(v, k) for k, v in
enumerate(clean_outputs)]) enumerate(clean_outputs)])
to_remove_set = set() to_remove_set = set()
to_replace_set = set() to_replace_set = set()
to_replace_map = OrderedDict() to_replace_map = OrderedDict()
nto_replace = 0
def add_to_replace(y): def add_to_replace(y):
to_replace_set.add(y) to_replace_set.add(y)
...@@ -471,12 +469,14 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -471,12 +469,14 @@ class PushOutSeqScan(gof.Optimizer):
# Construct the list of non_sequences to simplify a few things # Construct the list of non_sequences to simplify a few things
inner_non_seqs = op.inner_non_seqs(clean_inputs) inner_non_seqs = op.inner_non_seqs(clean_inputs)
inner_non_seqs_set = set(inner_non_seqs) inner_non_seqs_set = set(inner_non_seqs)
inner_non_seqs_map = dict([(v,k) for k,v in enumerate(inner_non_seqs)]) inner_non_seqs_map = dict([(v, k) for k, v in
enumerate(inner_non_seqs)])
outer_non_seqs = op.outer_non_seqs(node.inputs) outer_non_seqs = op.outer_non_seqs(node.inputs)
inner_seqs = op.inner_seqs(clean_inputs) inner_seqs = op.inner_seqs(clean_inputs)
inner_seqs_set = set(inner_seqs) inner_seqs_set = set(inner_seqs)
inner_seqs_map = dict([(v,k) for k,v in enumerate(inner_seqs)]) inner_seqs_map = dict([(v, k) for k, v in
enumerate(inner_seqs)])
outer_seqs = op.outer_seqs(node.inputs) outer_seqs = op.outer_seqs(node.inputs)
assert len(inner_non_seqs) == len(outer_non_seqs) assert len(inner_non_seqs) == len(outer_non_seqs)
...@@ -488,7 +488,7 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -488,7 +488,7 @@ class PushOutSeqScan(gof.Optimizer):
(x.owner in to_remove_set) or (x.owner in to_remove_set) or
isinstance(x, tensor.Constant) or isinstance(x, tensor.Constant) or
(x in inner_seqs_set) for x in nd.inputs]) and (x in inner_seqs_set) for x in nd.inputs]) and
isinstance(nd.op, theano.tensor.Elemwise)): isinstance(nd.op, theano.tensor.Elemwise)):
outside_ins = [] outside_ins = []
depends_on_seqs = False depends_on_seqs = False
...@@ -538,7 +538,7 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -538,7 +538,7 @@ class PushOutSeqScan(gof.Optimizer):
elif (nd not in to_remove_set and elif (nd not in to_remove_set and
isinstance(nd.op, theano.tensor.DimShuffle) and isinstance(nd.op, theano.tensor.DimShuffle) and
(nd.inputs[0] in inner_seqs_set or (nd.inputs[0] in inner_seqs_set or
nd.inputs[0].owner in to_remove_set)): nd.inputs[0].owner in to_remove_set)):
to_remove_set.add(nd) to_remove_set.add(nd)
x = nd.inputs[0] x = nd.inputs[0]
...@@ -582,11 +582,10 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -582,11 +582,10 @@ class PushOutSeqScan(gof.Optimizer):
to_keep_set.update(nd.inputs) to_keep_set.update(nd.inputs)
for out, idx in to_replace_map.items(): for out, idx in to_replace_map.items():
if (out in to_keep_set if (out in to_keep_set and out.owner not in existent_nodes_set and
and out.owner not in existent_nodes_set # If types are different, conversion Op will be inserted,
# If types are different, conversion Op will be inserted, # and it may trigger an infinite loop.
# and it may trigger an infinite loop. replace_with_in[idx].type == out.type):
and replace_with_in[idx].type == out.type):
clean_to_replace.append(out) clean_to_replace.append(out)
clean_replace_with_in.append(replace_with_in[idx]) clean_replace_with_in.append(replace_with_in[idx])
...@@ -682,7 +681,7 @@ class PushOutScanOutput(gof.Optimizer): ...@@ -682,7 +681,7 @@ class PushOutScanOutput(gof.Optimizer):
not x.op.as_while)] not x.op.as_while)]
for node in nodelist: for node in nodelist:
# Process the node as long as something gets optimized # Process the node as long as something gets optimized
while node != None: while node is not None:
node = self.process_node(fgraph, node) node = self.process_node(fgraph, node)
def process_node(self, fgraph, node): def process_node(self, fgraph, node):
...@@ -702,7 +701,7 @@ class PushOutScanOutput(gof.Optimizer): ...@@ -702,7 +701,7 @@ class PushOutScanOutput(gof.Optimizer):
local_fgraph_topo = local_fgraph.toposort() local_fgraph_topo = local_fgraph.toposort()
for nd in local_fgraph_topo: for nd in local_fgraph_topo:
if (isinstance(nd.op, theano.tensor.Dot) and if (isinstance(nd.op, theano.tensor.Dot) and
nd.out in args.inner_out_nit_sot): nd.out in args.inner_out_nit_sot):
""" """
The following optimization involves pushing out, after the The following optimization involves pushing out, after the
scan, a Dot whose output is nitsot (not feed back to the inner scan, a Dot whose output is nitsot (not feed back to the inner
...@@ -737,8 +736,8 @@ class PushOutScanOutput(gof.Optimizer): ...@@ -737,8 +736,8 @@ class PushOutScanOutput(gof.Optimizer):
(nd.inputs[0] in args.inner_in_non_seqs or (nd.inputs[0] in args.inner_in_non_seqs or
isinstance(nd.inputs[0], tensor.Constant)) and isinstance(nd.inputs[0], tensor.Constant)) and
nd.inputs[1].ndim == 1 and nd.inputs[1].ndim == 1 and
(nd.inputs[1] in args.inner_in_seqs or (nd.inputs[1] in args.inner_in_seqs or
nd.inputs[1] not in args.inner_inputs)): nd.inputs[1] not in args.inner_inputs)):
valid_inputs = True valid_inputs = True
idx_matrix_input = 0 idx_matrix_input = 0
...@@ -778,11 +777,10 @@ class PushOutScanOutput(gof.Optimizer): ...@@ -778,11 +777,10 @@ class PushOutScanOutput(gof.Optimizer):
outer_dot_output = theano.tensor.dot(*outer_dot_inputs) outer_dot_output = theano.tensor.dot(*outer_dot_inputs)
# Modify the outer graph to add the outer Dot # Modify the outer graph to add the outer Dot
fgraph.replace_all([ fgraph.replace_all(
(new_scan_args.outer_out_nit_sot[ [(new_scan_args.outer_out_nit_sot[dot_out_nitsot_idx],
dot_out_nitsot_idx], outer_dot_output)],
outer_dot_output)], reason="scanOp_pushout_output")
reason="scanOp_pushout_output")
break break
...@@ -807,8 +805,9 @@ class PushOutScanOutput(gof.Optimizer): ...@@ -807,8 +805,9 @@ class PushOutScanOutput(gof.Optimizer):
sitsot_in_idx = nd.inputs.index(args.inner_in_sit_sot[ sitsot_in_idx = nd.inputs.index(args.inner_in_sit_sot[
sitsot_idx]) sitsot_idx])
dot_in_idx = 1 - sitsot_in_idx # 0 if sitsot_in_idx==1, # 0 if sitsot_in_idx==1, 1 if sitsot_in_idx==0
# 1 if sitsot_in_idx==0 dot_in_idx = 1 - sitsot_in_idx
dot_input = nd.inputs[dot_in_idx] dot_input = nd.inputs[dot_in_idx]
if (dot_input.owner is not None and if (dot_input.owner is not None and
...@@ -816,10 +815,8 @@ class PushOutScanOutput(gof.Optimizer): ...@@ -816,10 +815,8 @@ class PushOutScanOutput(gof.Optimizer):
len(dot_input.clients) == 1 and len(dot_input.clients) == 1 and
dot_input.owner.inputs[0].ndim == 2 and dot_input.owner.inputs[0].ndim == 2 and
dot_input.owner.inputs[1].ndim == 2 and dot_input.owner.inputs[1].ndim == 2 and
self.get_outer_ndim(dot_input.owner.inputs[0], args) \ self.get_outer_ndim(dot_input.owner.inputs[0], args) == 3 and
== 3 and self.get_outer_ndim(dot_input.owner.inputs[1], args) == 3):
self.get_outer_ndim(dot_input.owner.inputs[1], args) \
== 3):
# The optimization can be be applied in this case. # The optimization can be be applied in this case.
...@@ -829,36 +826,32 @@ class PushOutScanOutput(gof.Optimizer): ...@@ -829,36 +826,32 @@ class PushOutScanOutput(gof.Optimizer):
(outer_dot_inputs, (outer_dot_inputs,
new_scan_node, new_scan_node,
new_scan_args) = \ new_scan_args) = \
self.push_out_inner_vars(fgraph, self.push_out_inner_vars(fgraph, inner_dot_inputs,
inner_dot_inputs, node, args)
node, args)
# Collapse some of the dimensions of the tensors # Collapse some of the dimensions of the tensors
# so that they become matrices. This is because a # so that they become matrices. This is because a
# dot is usually faster on two large matrices than # dot is usually faster on two large matrices than
# a bunch of small ones # a bunch of small ones
outer_dot_inputs[0] = theano.tensor.flatten( outer_dot_inputs[0] = theano.tensor.flatten(
outer_dot_inputs[0].dimshuffle(1, 0, 2), outer_dot_inputs[0].dimshuffle(1, 0, 2), outdim=2)
outdim=2)
shape_input1 = theano.tensor.shape(outer_dot_inputs[1]) shape_input1 = theano.tensor.shape(outer_dot_inputs[1])
outer_dot_inputs[1] =\ outer_dot_inputs[1] =\
outer_dot_inputs[1].reshape((shape_input1[0] * outer_dot_inputs[1].reshape((shape_input1[0] *
shape_input1[1], shape_input1[1],
shape_input1[2])) shape_input1[2]))
# Perform the dot on the newly obtained matrices and # Perform the dot on the newly obtained matrices and
# add the initial value # add the initial value
outer_dot_output = theano.tensor.dot(*outer_dot_inputs) outer_dot_output = theano.tensor.dot(*outer_dot_inputs)
init_value = \ init_value = new_scan_args.outer_in_sit_sot[sitsot_idx][0]
new_scan_args.outer_in_sit_sot[sitsot_idx][0]
replacement = outer_dot_output + init_value replacement = outer_dot_output + init_value
# Alter the outer graph to use the output of the # Alter the outer graph to use the output of the
# external Dot instead of the output of scan # external Dot instead of the output of scan
# Modify the outer graph to add the outer Dot # Modify the outer graph to add the outer Dot
outer_sitsot = \ outer_sitsot = new_scan_args.outer_out_sit_sot[sitsot_idx]
new_scan_args.outer_out_sit_sot[sitsot_idx]
subtensor_node = outer_sitsot.clients[0][0] subtensor_node = outer_sitsot.clients[0][0]
outer_sitsot_last_step = subtensor_node.outputs[0] outer_sitsot_last_step = subtensor_node.outputs[0]
...@@ -883,12 +876,12 @@ class PushOutScanOutput(gof.Optimizer): ...@@ -883,12 +876,12 @@ class PushOutScanOutput(gof.Optimizer):
if len(outer_var.clients) == 1: if len(outer_var.clients) == 1:
client = outer_var.clients[0][0] client = outer_var.clients[0][0]
if (client != 'output' and if (client != 'output' and isinstance(client.op,
isinstance(client.op, theano.tensor.Subtensor)): theano.tensor.Subtensor)):
lst = theano.tensor.subtensor.get_idx_list( lst = theano.tensor.subtensor.get_idx_list(
client.inputs, client.op.idx_list) client.inputs, client.op.idx_list)
if (len(lst) == 1 and if (len(lst) == 1 and
theano.tensor.extract_constant(lst[0]) == -1): theano.tensor.extract_constant(lst[0]) == -1):
return True return True
return False return False
...@@ -898,7 +891,7 @@ class PushOutScanOutput(gof.Optimizer): ...@@ -898,7 +891,7 @@ class PushOutScanOutput(gof.Optimizer):
# Given a variable, determine the number of dimension it would have if # Given a variable, determine the number of dimension it would have if
# it was pushed out of scan # it was pushed out of scan
if (var in scan_args.inner_in_non_seqs or if (var in scan_args.inner_in_non_seqs or
isinstance(var, theano.Constant)): isinstance(var, theano.Constant)):
outer_ndim = var.ndim outer_ndim = var.ndim
else: else:
...@@ -990,8 +983,8 @@ class PushOutScanOutput(gof.Optimizer): ...@@ -990,8 +983,8 @@ class PushOutScanOutput(gof.Optimizer):
len(old_scan_args.outer_out_shared)) len(old_scan_args.outer_out_shared))
new_node_old_outputs = ( new_node_old_outputs = (
new_scan_node.outputs[:new_node_new_outputs_idx] + new_scan_node.outputs[:new_node_new_outputs_idx] +
new_scan_node.outputs[new_node_new_outputs_idx+nb_new_outs:]) new_scan_node.outputs[new_node_new_outputs_idx + nb_new_outs:])
fgraph.replace_all_validate_remove( fgraph.replace_all_validate_remove(
list(zip(old_scan_node.outputs, new_node_old_outputs)), list(zip(old_scan_node.outputs, new_node_old_outputs)),
...@@ -1242,7 +1235,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -1242,7 +1235,7 @@ class ScanSaveMem(gof.Optimizer):
for cl, _ in out.clients: for cl, _ in out.clients:
# 2.1 outputs of the function # 2.1 outputs of the function
#=> output needs all its intermediate values # => output needs all its intermediate values
if type(cl) == str: if type(cl) == str:
# if the node is actually an output, then # if the node is actually an output, then
# we need to store the entire thing # we need to store the entire thing
...@@ -1250,20 +1243,20 @@ class ScanSaveMem(gof.Optimizer): ...@@ -1250,20 +1243,20 @@ class ScanSaveMem(gof.Optimizer):
slices[i] = None slices[i] = None
break break
# 2.2 non-subtensor nodes # 2.2 non-subtensor nodes
#=> output needs all its intermediate values # => output needs all its intermediate values
elif not isinstance(cl.op, tensor.Subtensor): elif not isinstance(cl.op, tensor.Subtensor):
global_nsteps = None global_nsteps = None
slices[i] = None slices[i] = None
break break
# 2.3 subtensor nodes # 2.3 subtensor nodes
#=> output might need to store just a subset of its values # => output might need to store just a subset of its values
else: else:
# 2.3.1 extract idx list of subtensor # 2.3.1 extract idx list of subtensor
this_slice = tensor.get_idx_list(cl.inputs, this_slice = tensor.get_idx_list(cl.inputs,
cl.op.idx_list) cl.op.idx_list)
if this_slice is None: if this_slice is None:
# if unable to extract idx_list # if unable to extract idx_list
#=> outputs needs all its intermediate values # => outputs needs all its intermediate values
global_nsteps = None global_nsteps = None
slices[i] = None slices[i] = None
break break
...@@ -1284,7 +1277,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -1284,7 +1277,7 @@ class ScanSaveMem(gof.Optimizer):
slices[i] += [(cf_slice, this_slice)] slices[i] += [(cf_slice, this_slice)]
if (isinstance(this_slice[0], slice) and if (isinstance(this_slice[0], slice) and
this_slice[0].stop is None): this_slice[0].stop is None):
global_nsteps = None global_nsteps = None
if isinstance(cf_slice[0], slice): if isinstance(cf_slice[0], slice):
stop = tensor.basic.extract_constant(cf_slice[0].stop) stop = tensor.basic.extract_constant(cf_slice[0].stop)
...@@ -1374,7 +1367,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -1374,7 +1367,7 @@ class ScanSaveMem(gof.Optimizer):
break break
if (isinstance(this_slice[0], slice) and if (isinstance(this_slice[0], slice) and
this_slice[0].start is None): this_slice[0].start is None):
store_steps[i] = 0 store_steps[i] = 0
break break
...@@ -1406,8 +1399,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -1406,8 +1399,7 @@ class ScanSaveMem(gof.Optimizer):
# for mitsots and sitsots (because mitmots are not # for mitsots and sitsots (because mitmots are not
# currently supported by the mechanism) and only if # currently supported by the mechanism) and only if
# the pre-allocation mechanism is activated. # the pre-allocation mechanism is activated.
prealloc_outs = \ prealloc_outs = theano.config.scan.allow_output_prealloc
theano.config.scan.allow_output_prealloc
first_mitsot_idx = node.op.n_mit_mot first_mitsot_idx = node.op.n_mit_mot
last_sitsot_idx = (node.op.n_mit_mot + last_sitsot_idx = (node.op.n_mit_mot +
...@@ -1433,7 +1425,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -1433,7 +1425,7 @@ class ScanSaveMem(gof.Optimizer):
# currently. # currently.
# pval = pre_greedy_local_optimizer(list_opt_slice, # pval = pre_greedy_local_optimizer(list_opt_slice,
# pval) # pval)
#pval = pre_constant_merge([pval])[0] # pval = pre_constant_merge([pval])[0]
# if (isinstance(pval, theano.tensor.TensorConstant) # if (isinstance(pval, theano.tensor.TensorConstant)
# and # and
# pval.dtype.startswith('int')): # pval.dtype.startswith('int')):
...@@ -1554,7 +1546,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -1554,7 +1546,7 @@ class ScanSaveMem(gof.Optimizer):
nw_steps) nw_steps)
nw_inputs[in_idx] = nw_input nw_inputs[in_idx] = nw_input
else: else:
nw_input = nw_inputs[in_idx][:(initl+nw_steps)] nw_input = nw_inputs[in_idx][:(initl + nw_steps)]
elif idx < op.n_mit_sot + op.n_sit_sot + op.n_nit_sot: elif idx < op.n_mit_sot + op.n_sit_sot + op.n_nit_sot:
in_idx = offset + idx + op.n_shared_outs in_idx = offset + idx + op.n_shared_outs
...@@ -1563,7 +1555,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -1563,7 +1555,7 @@ class ScanSaveMem(gof.Optimizer):
# 3.5 Remove unwanted orphane outputs # 3.5 Remove unwanted orphane outputs
(inps, outs, info, node_ins, compress_map) = \ (inps, outs, info, node_ins, compress_map) = \
scan_utils.compress_outs(op, not_required, nw_inputs) scan_utils.compress_outs(op, not_required, nw_inputs)
inv_compress_map = OrderedDict() inv_compress_map = OrderedDict()
for k, v in iteritems(compress_map): for k, v in iteritems(compress_map):
inv_compress_map[v] = k inv_compress_map[v] = k
...@@ -1633,15 +1625,15 @@ class ScanSaveMem(gof.Optimizer): ...@@ -1633,15 +1625,15 @@ class ScanSaveMem(gof.Optimizer):
start = (cnf_slice[0].start - nw_steps - start = (cnf_slice[0].start - nw_steps -
init_l[pos] + store_steps[pos]) init_l[pos] + store_steps[pos])
if (cnf_slice[0].stop is not None and if (cnf_slice[0].stop is not None and
cnf_slice[0].stop != maxsize): cnf_slice[0].stop != maxsize):
stop = (cnf_slice[0].stop - nw_steps - stop = (cnf_slice[0].stop - nw_steps -
init_l[pos] + store_steps[pos]) init_l[pos] + store_steps[pos])
else: else:
stop = None stop = None
nw_slice = ((slice(sanitize(start), nw_slice = ((slice(sanitize(start),
sanitize(stop), sanitize(stop),
sanitize(cnf_slice[0].step)),) sanitize(cnf_slice[0].step)),) +
+ tuple(old_slices[1:])) tuple(old_slices[1:]))
else: else:
position = (cnf_slice[0] - nw_steps - position = (cnf_slice[0] - nw_steps -
...@@ -1662,8 +1654,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -1662,8 +1654,7 @@ class ScanSaveMem(gof.Optimizer):
# 3.9. Get replace pairs for all other nodes # 3.9. Get replace pairs for all other nodes
if flag_store or global_nsteps is not None: if flag_store or global_nsteps is not None:
for idx, o in enumerate(node.outputs): for idx, o in enumerate(node.outputs):
if not (idx in replaced_outs) and \ if not (idx in replaced_outs) and idx not in not_required:
not idx in not_required:
nw_pos = compress_map[idx] nw_pos = compress_map[idx]
old_new += [(o, new_outs[nw_pos])] old_new += [(o, new_outs[nw_pos])]
# Check if the new outputs depend on the old scan node # Check if the new outputs depend on the old scan node
...@@ -1808,7 +1799,7 @@ class ScanMerge(gof.Optimizer): ...@@ -1808,7 +1799,7 @@ class ScanMerge(gof.Optimizer):
flat_inner_outs = sum(inner_outs[idx], []) flat_inner_outs = sum(inner_outs[idx], [])
# clone # clone
flat_inner_ins, flat_inner_outs = scan_utils.reconstruct_graph( flat_inner_ins, flat_inner_outs = scan_utils.reconstruct_graph(
flat_inner_ins, flat_inner_outs) flat_inner_ins, flat_inner_outs)
# split the new inner variables again in seq, mitmot, etc. # split the new inner variables again in seq, mitmot, etc.
new_inner_ins = [] new_inner_ins = []
count = 0 count = 0
...@@ -2072,8 +2063,8 @@ def scan_merge_inouts(node): ...@@ -2072,8 +2063,8 @@ def scan_merge_inouts(node):
# because they could have different sizes, and the corresponding # because they could have different sizes, and the corresponding
# outer outputs cannot be merged in that case. # outer outputs cannot be merged in that case.
for s_outer_i, s_inner_o, s_outer_o in seen: for s_outer_i, s_inner_o, s_outer_o in seen:
if (equal_computations([inner_o], [s_inner_o], left, right) if (equal_computations([inner_o], [s_inner_o], left, right) and
and outer_i == s_outer_i): outer_i == s_outer_i):
return s_outer_o return s_outer_o
seen.append((outer_i, inner_o, outer_o)) seen.append((outer_i, inner_o, outer_o))
return outer_o return outer_o
...@@ -2116,9 +2107,10 @@ def scan_merge_inouts(node): ...@@ -2116,9 +2107,10 @@ def scan_merge_inouts(node):
na.outer_out_mit_mot, na.outer_out_mit_mot,
na.mit_mot_out_slices): na.mit_mot_out_slices):
for s_outer_imm, s_inner_omm, s_outer_omm, sosl in seen: for s_outer_imm, s_inner_omm, s_outer_omm, sosl in seen:
if (osl == sosl if (osl == sosl and
and equal_computations(inner_omm, s_inner_omm, left, right) equal_computations(inner_omm, s_inner_omm, left, right) and
and outer_imm == s_outer_imm): outer_imm == s_outer_imm):
new_outer_out_mit_mot.append(s_outer_omm) new_outer_out_mit_mot.append(s_outer_omm)
break break
else: else:
...@@ -2168,17 +2160,15 @@ class PushOutDot1(gof.Optimizer): ...@@ -2168,17 +2160,15 @@ class PushOutDot1(gof.Optimizer):
inp in out.owner.inputs and inp in out.owner.inputs and
len(outer_out.clients) == 1 and len(outer_out.clients) == 1 and
not isinstance(outer_out.clients[0][0], str) and not isinstance(outer_out.clients[0][0], str) and
isinstance(outer_out.clients[0][0].op, theano.tensor.Subtensor) isinstance(outer_out.clients[0][0].op, theano.tensor.Subtensor) and
and outer_out.clients[0][0].op.idx_list == (-1,)): outer_out.clients[0][0].op.idx_list == (-1,)):
x = out.owner.inputs[0] x = out.owner.inputs[0]
if x == inp: if x == inp:
x = out.owner.inputs[1] x = out.owner.inputs[1]
# We need to check if x is the result of an outer product # We need to check if x is the result of an outer product
if (x.owner and if (x.owner and isinstance(x.owner.op, theano.tensor.Dot) and
isinstance(x.owner.op, theano.tensor.Dot) and x.owner.inputs[0].ndim == 2 and x.owner.inputs[1].ndim == 2):
x.owner.inputs[0].ndim == 2 and
x.owner.inputs[1].ndim == 2):
# We need to check if any of the inputs are a sequence # We need to check if any of the inputs are a sequence
inp1 = x.owner.inputs[0] inp1 = x.owner.inputs[0]
...@@ -2219,18 +2209,17 @@ class PushOutDot1(gof.Optimizer): ...@@ -2219,18 +2209,17 @@ class PushOutDot1(gof.Optimizer):
new_info = op.info.copy() new_info = op.info.copy()
st = len(op.mitmot_taps()) + len(op.mitsot_taps()) st = len(op.mitmot_taps()) + len(op.mitsot_taps())
new_info['tap_array'] = (\ new_info['tap_array'] = (
new_info['tap_array'][:st + idx] + new_info['tap_array'][:st + idx] +
new_info['tap_array'][st + new_info['tap_array'][st + idx + 1:])
idx + 1:])
new_info['n_sit_sot'] -= 1 new_info['n_sit_sot'] -= 1
new_info['n_nit_sot'] += 1 new_info['n_nit_sot'] += 1
inner_sitsot = inner_sitsot[:idx] + \ inner_sitsot = (inner_sitsot[:idx] +
inner_sitsot[idx + 1:] inner_sitsot[idx + 1:])
outer_sitsot = outer_sitsot[:idx] + \ outer_sitsot = (outer_sitsot[:idx] +
outer_sitsot[idx + 1:] outer_sitsot[idx + 1:])
inner_sitsot_outs = inner_sitsot_outs[:idx] +\ inner_sitsot_outs = (inner_sitsot_outs[:idx] +
inner_sitsot_outs[idx + 1:] inner_sitsot_outs[idx + 1:])
# add n_steps as the length # add n_steps as the length
inner_nitsot_outs.append(new_scan_out) inner_nitsot_outs.append(new_scan_out)
...@@ -2246,8 +2235,8 @@ class PushOutDot1(gof.Optimizer): ...@@ -2246,8 +2235,8 @@ class PushOutDot1(gof.Optimizer):
inner_nitsot_outs + inner_nitsot_outs +
inner_shared_outs) inner_shared_outs)
new_inner_inps, new_inner_outs =\ new_inner_inps, new_inner_outs =\
scan_utils.reconstruct_graph( scan_utils.reconstruct_graph(_new_inner_inps,
_new_inner_inps, _new_inner_outs) _new_inner_outs)
new_op = scan_op.Scan(new_inner_inps, new_inner_outs, new_op = scan_op.Scan(new_inner_inps, new_inner_outs,
new_info) new_info)
_scan_inputs = ([node.inputs[0]] + _scan_inputs = ([node.inputs[0]] +
...@@ -2267,11 +2256,7 @@ class PushOutDot1(gof.Optimizer): ...@@ -2267,11 +2256,7 @@ class PushOutDot1(gof.Optimizer):
# We need now to pair correctly the new outputs # We need now to pair correctly the new outputs
# with the old ones # with the old ones
outer_mitmot_outs = new_op.outer_mitmot_outs(new_outs)
outer_mitsot_outs = new_op.outer_mitsot_outs(new_outs)
outer_sitsot_outs = new_op.outer_sitsot_outs(new_outs)
outer_nitsot_outs = new_op.outer_nitsot_outs(new_outs) outer_nitsot_outs = new_op.outer_nitsot_outs(new_outs)
outer_shared_outs = new_op.outer_shared_outs(new_outs)
_val = outer_nitsot_outs[-1] _val = outer_nitsot_outs[-1]
outer_nitsot_outs = outer_nitsot_outs[:-1] outer_nitsot_outs = outer_nitsot_outs[:-1]
...@@ -2305,7 +2290,7 @@ class PushOutDot1(gof.Optimizer): ...@@ -2305,7 +2290,7 @@ class PushOutDot1(gof.Optimizer):
old_new = list(zip(node.outputs[:pos], new_outs[:pos])) old_new = list(zip(node.outputs[:pos], new_outs[:pos]))
old = node.outputs[pos].clients[0][0].outputs[0] old = node.outputs[pos].clients[0][0].outputs[0]
old_new.append((old, new_out)) old_new.append((old, new_out))
old_new += list(zip(node.outputs[pos+1:], old_new += list(zip(node.outputs[pos + 1:],
new_outs[pos:])) new_outs[pos:]))
fgraph.replace_all_validate_remove( fgraph.replace_all_validate_remove(
old_new, remove=[node], reason='scan_pushout_dot1') old_new, remove=[node], reason='scan_pushout_dot1')
...@@ -2374,61 +2359,61 @@ scan_seqopt1.register('scanOp_pushout_output', ...@@ -2374,61 +2359,61 @@ scan_seqopt1.register('scanOp_pushout_output',
scan_eqopt2.register('constant_folding_for_scan2', scan_eqopt2.register('constant_folding_for_scan2',
opt.in2out(tensor.opt.constant_folding, opt.in2out(tensor.opt.constant_folding,
ignore_newtrees=True), ignore_newtrees=True),
1, 1,
'fast_run', 'fast_run',
'scan') 'scan')
scan_eqopt2.register('scanOp_remove_constants_and_unused_inputs1', scan_eqopt2.register('scanOp_remove_constants_and_unused_inputs1',
opt.in2out(remove_constants_and_unused_inputs_scan, opt.in2out(remove_constants_and_unused_inputs_scan,
ignore_newtrees=True), ignore_newtrees=True),
2, 2,
'remove_constants_and_unused_inputs_scan', 'remove_constants_and_unused_inputs_scan',
'fast_run', 'fast_run',
'scan') 'scan')
# after const merge but before stabilize so that we can have identity # after const merge but before stabilize so that we can have identity
# for equivalent nodes but we still have the chance to hoist stuff out # for equivalent nodes but we still have the chance to hoist stuff out
# of the scan later. # of the scan later.
scan_eqopt2.register('scanOp_merge', scan_eqopt2.register('scanOp_merge',
ScanMerge(), ScanMerge(),
4, 4,
'fast_run', 'fast_run',
'scan') 'scan')
# After Merge optimization # After Merge optimization
scan_eqopt2.register('scanop_remove_constants_and_unused_inputs2', scan_eqopt2.register('scanop_remove_constants_and_unused_inputs2',
opt.in2out(remove_constants_and_unused_inputs_scan, opt.in2out(remove_constants_and_unused_inputs_scan,
ignore_newtrees=True), ignore_newtrees=True),
5, 5,
'remove_constants_and_unused_inputs_scan', 'remove_constants_and_unused_inputs_scan',
'fast_run', 'fast_run',
'scan') 'scan')
scan_eqopt2.register('scanOp_merge_inouts', scan_eqopt2.register('scanOp_merge_inouts',
opt.in2out(scan_merge_inouts, ignore_newtrees=True), opt.in2out(scan_merge_inouts, ignore_newtrees=True),
6, 6,
'scan_merge_inouts', 'scan_merge_inouts',
'fast_run', 'fast_run',
'scan') 'scan')
# Just before specialize to have the other optimization # Just before specialize to have the other optimization
# like constant folding being applied # like constant folding being applied
# This don't introduce inplace. # This don't introduce inplace.
scan_eqopt2.register('scanOp_save_mem', scan_eqopt2.register('scanOp_save_mem',
ScanSaveMem(), ScanSaveMem(),
7, 7,
'fast_run', 'fast_run',
'scan') 'scan')
# After everything else # After everything else
scan_eqopt2.register('scanOp_remove_constants_and_unused_inputs3', scan_eqopt2.register('scanOp_remove_constants_and_unused_inputs3',
opt.in2out(remove_constants_and_unused_inputs_scan, opt.in2out(remove_constants_and_unused_inputs_scan,
ignore_newtrees=True), ignore_newtrees=True),
8, 8,
'remove_constants_and_unused_inputs_scan', 'remove_constants_and_unused_inputs_scan',
'fast_run', 'fast_run',
'scan') 'scan')
...@@ -164,7 +164,6 @@ whitelist_flake8 = [ ...@@ -164,7 +164,6 @@ whitelist_flake8 = [
"scan_module/scan_op.py", "scan_module/scan_op.py",
"scan_module/scan_perform_ext.py", "scan_module/scan_perform_ext.py",
"scan_module/__init__.py", "scan_module/__init__.py",
"scan_module/scan_opt.py",
"scan_module/tests/test_scan.py", "scan_module/tests/test_scan.py",
"scan_module/tests/test_scan_opt.py", "scan_module/tests/test_scan_opt.py",
"misc/tests/test_may_share_memory.py", "misc/tests/test_may_share_memory.py",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论