提交 c8dc3dbe authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #3524 from carriepl/scan_inplace_opt

Scan inplace opt
......@@ -316,7 +316,7 @@ if cuda_available:
GpuDimShuffle, GpuCAReduce, GpuReshape, GpuContiguous,
GpuSubtensor, GpuIncSubtensor,
GpuAdvancedSubtensor1, GpuAdvancedIncSubtensor1,
GpuFlatten, GpuShape, GpuAlloc, GpuSplit,
GpuFlatten, GpuShape, GpuAlloc, GpuAllocEmpty, GpuSplit,
GpuJoin, fscalar, fvector, fmatrix, frow, fcol,
ftensor3, ftensor4,
scalar, vector, matrix, row, col,
......@@ -341,7 +341,7 @@ def use(device,
Parameters
----------
device : string
device : string
"cpu", "gpu", "gpuN" (N is the device number to use).
force
Will always raise an exception if we can't use the gpu.
......
......@@ -68,8 +68,9 @@ if pygpu:
theano.compile.shared_constructor(gpuarray_shared_constructor)
optdb.add_tags('gpuarray_opt', 'fast_run', 'fast_compile')
from .basic_ops import (GpuAlloc, GpuContiguous, GpuEye, GpuFromHost,
GpuJoin, GpuReshape, GpuSplit, HostFromGpu)
from .basic_ops import (GpuAlloc, GpuAllocEmpty, GpuContiguous, GpuEye,
GpuFromHost, GpuJoin, GpuReshape, GpuSplit,
HostFromGpu)
from .basic_ops import host_from_gpu, GpuFromHost
from .elemwise import GpuElemwise
from .subtensor import (GpuSubtensor, GpuIncSubtensor,
......
......@@ -48,16 +48,7 @@ scan_eqopt2 -> They are all global optimizer. (in2out convert local to global).
in2out(scan_merge_inouts),
ScanSaveMem,
in2out(remove_constants_and_unused_inputs_scan3)
"""
__docformat__ = 'restructedtext en'
__authors__ = ("Razvan Pascanu "
"Frederic Bastien "
"James Bergstra "
"Pascal Lamblin "
"Arnaud Bergeron ")
__copyright__ = "(c) 2010, Universite de Montreal"
__contact__ = "Razvan Pascanu <r.pascanu@gmail>"
import logging
import copy
......@@ -66,21 +57,29 @@ import numpy
import theano
from theano import tensor
from theano.tensor import opt, get_scalar_constant_value
from theano.tensor import opt, get_scalar_constant_value, Alloc, AllocEmpty
from theano import gof
from theano.compat import OrderedDict
from six import integer_types, iteritems
from six.moves import xrange
from theano.gof.opt import Optimizer
from theano.gof.opt import pre_constant_merge, pre_greedy_local_optimizer
from theano.gof import toolbox, DestroyHandler, InconsistencyError
from theano.compile import optdb
from theano.compile.function_module import deep_copy_op
from theano.gof import toolbox, DestroyHandler, InconsistencyError
from theano.gof.opt import Optimizer
from theano.gof.opt import pre_constant_merge, pre_greedy_local_optimizer
from theano.scan_module import scan_op
from theano.scan_module import scan_utils
from theano.scan_module.scan_utils import equal_computations, find_up, \
scan_args
from theano.scan_module.scan_utils import equal_computations, find_up, scan_args
__docformat__ = 'restructedtext en'
__authors__ = ("Razvan Pascanu "
"Frederic Bastien "
"James Bergstra "
"Pascal Lamblin "
"Arnaud Bergeron ")
__copyright__ = "(c) 2010, Universite de Montreal"
__contact__ = "Razvan Pascanu <r.pascanu@gmail>"
# Logging function for sending warning or info
......@@ -151,7 +150,7 @@ def remove_constants_and_unused_inputs_scan(node):
for idx in xrange(op.n_seqs):
node_inp = node.inputs[idx + 1]
if (isinstance(node_inp, tensor.TensorConstant) and
node_inp.tag.unique_value is not None):
node_inp.tag.unique_value is not None):
try:
# This works if input is a constant that has all entries
# equal
......@@ -243,18 +242,17 @@ class PushOutNonSeqScan(gof.Optimizer):
local_fgraph_topo = theano.gof.graph.io_toposort(clean_inputs,
clean_outputs)
local_fgraph_outs_set = set(clean_outputs)
local_fgraph_outs_map = dict([(v, k) for k, v in \
enumerate(clean_outputs)])
local_fgraph_outs_map = dict([(v, k) for k, v in
enumerate(clean_outputs)])
to_remove_set = set()
to_replace_set = set()
to_replace_map = OrderedDict()
nto_replace = 0
def add_to_replace(y):
to_replace_set.add(y)
to_replace_map[y] = add_to_replace.n
add_to_replace.n +=1
add_to_replace.n += 1
add_to_replace.n = 0
replace_with_in = []
......@@ -264,7 +262,8 @@ class PushOutNonSeqScan(gof.Optimizer):
# Construct the list of non_sequences to simplify a few things
inner_non_seqs = op.inner_non_seqs(clean_inputs)
inner_non_seqs_set = set(inner_non_seqs)
inner_non_seqs_map = dict([(v,k) for k,v in enumerate(inner_non_seqs)])
inner_non_seqs_map = dict([(v, k) for k, v in
enumerate(inner_non_seqs)])
outer_non_seqs = op.outer_non_seqs(node.inputs)
......@@ -275,17 +274,17 @@ class PushOutNonSeqScan(gof.Optimizer):
assert len(inner_seqs) == len(outer_seqs)
for nd in local_fgraph_topo:
if (# we haven't already looked at this node
nd not in to_remove_set and
all([((x in inner_non_seqs_set) or
(x.owner in to_remove_set) or
isinstance(x, tensor.Constant))
for x in nd.inputs]) and
# we can do this because the assumption is that a
# viewOp or deepCopyOp will be just at the end of the
# function and not somewhere in the middle ..
not isinstance(nd.op, theano.compile.ViewOp) and
not isinstance(nd.op, theano.compile.DeepCopyOp)):
if ( # we haven't already looked at this node
nd not in to_remove_set and
all([((x in inner_non_seqs_set) or
(x.owner in to_remove_set) or
isinstance(x, tensor.Constant))
for x in nd.inputs]) and
# we can do this because the assumption is that a
# viewOp or deepCopyOp will be just at the end of the
# function and not somewhere in the middle ..
not isinstance(nd.op, theano.compile.ViewOp) and
not isinstance(nd.op, theano.compile.DeepCopyOp)):
# We have a candidate node to remove
# Step 1. Reconstruct it on outside
......@@ -337,11 +336,11 @@ class PushOutNonSeqScan(gof.Optimizer):
to_keep_set.update(nd.inputs)
for out, idx in to_replace_map.items():
if (# If types are different, conversion Op will be inserted,
# and it may trigger an infinite loop.
replace_with_in[idx].type == out.type and
out in to_keep_set and
out.owner not in existent_nodes_set):
if ( # If types are different, conversion Op will be inserted,
# and it may trigger an infinite loop.
replace_with_in[idx].type == out.type and
out in to_keep_set and
out.owner not in existent_nodes_set):
clean_to_replace.append(out)
clean_replace_with_in.append(replace_with_in[idx])
clean_replace_with_out.append(replace_with_out[idx])
......@@ -450,13 +449,12 @@ class PushOutSeqScan(gof.Optimizer):
local_fgraph_topo = theano.gof.graph.io_toposort(clean_inputs,
clean_outputs)
local_fgraph_outs_set = set(clean_outputs)
local_fgraph_outs_map = dict([(v,k) for k,v in \
enumerate(clean_outputs)])
local_fgraph_outs_map = dict([(v, k) for k, v in
enumerate(clean_outputs)])
to_remove_set = set()
to_replace_set = set()
to_replace_map = OrderedDict()
nto_replace = 0
def add_to_replace(y):
to_replace_set.add(y)
......@@ -471,12 +469,14 @@ class PushOutSeqScan(gof.Optimizer):
# Construct the list of non_sequences to simplify a few things
inner_non_seqs = op.inner_non_seqs(clean_inputs)
inner_non_seqs_set = set(inner_non_seqs)
inner_non_seqs_map = dict([(v,k) for k,v in enumerate(inner_non_seqs)])
inner_non_seqs_map = dict([(v, k) for k, v in
enumerate(inner_non_seqs)])
outer_non_seqs = op.outer_non_seqs(node.inputs)
inner_seqs = op.inner_seqs(clean_inputs)
inner_seqs_set = set(inner_seqs)
inner_seqs_map = dict([(v,k) for k,v in enumerate(inner_seqs)])
inner_seqs_map = dict([(v, k) for k, v in
enumerate(inner_seqs)])
outer_seqs = op.outer_seqs(node.inputs)
assert len(inner_non_seqs) == len(outer_non_seqs)
......@@ -488,7 +488,7 @@ class PushOutSeqScan(gof.Optimizer):
(x.owner in to_remove_set) or
isinstance(x, tensor.Constant) or
(x in inner_seqs_set) for x in nd.inputs]) and
isinstance(nd.op, theano.tensor.Elemwise)):
isinstance(nd.op, theano.tensor.Elemwise)):
outside_ins = []
depends_on_seqs = False
......@@ -538,7 +538,7 @@ class PushOutSeqScan(gof.Optimizer):
elif (nd not in to_remove_set and
isinstance(nd.op, theano.tensor.DimShuffle) and
(nd.inputs[0] in inner_seqs_set or
nd.inputs[0].owner in to_remove_set)):
nd.inputs[0].owner in to_remove_set)):
to_remove_set.add(nd)
x = nd.inputs[0]
......@@ -582,11 +582,10 @@ class PushOutSeqScan(gof.Optimizer):
to_keep_set.update(nd.inputs)
for out, idx in to_replace_map.items():
if (out in to_keep_set
and out.owner not in existent_nodes_set
# If types are different, conversion Op will be inserted,
# and it may trigger an infinite loop.
and replace_with_in[idx].type == out.type):
if (out in to_keep_set and out.owner not in existent_nodes_set and
# If types are different, conversion Op will be inserted,
# and it may trigger an infinite loop.
replace_with_in[idx].type == out.type):
clean_to_replace.append(out)
clean_replace_with_in.append(replace_with_in[idx])
......@@ -682,7 +681,7 @@ class PushOutScanOutput(gof.Optimizer):
not x.op.as_while)]
for node in nodelist:
# Process the node as long as something gets optimized
while node != None:
while node is not None:
node = self.process_node(fgraph, node)
def process_node(self, fgraph, node):
......@@ -702,7 +701,7 @@ class PushOutScanOutput(gof.Optimizer):
local_fgraph_topo = local_fgraph.toposort()
for nd in local_fgraph_topo:
if (isinstance(nd.op, theano.tensor.Dot) and
nd.out in args.inner_out_nit_sot):
nd.out in args.inner_out_nit_sot):
"""
The following optimization involves pushing out, after the
scan, a Dot whose output is nitsot (not fed back to the inner
......@@ -737,8 +736,8 @@ class PushOutScanOutput(gof.Optimizer):
(nd.inputs[0] in args.inner_in_non_seqs or
isinstance(nd.inputs[0], tensor.Constant)) and
nd.inputs[1].ndim == 1 and
(nd.inputs[1] in args.inner_in_seqs or
nd.inputs[1] not in args.inner_inputs)):
(nd.inputs[1] in args.inner_in_seqs or
nd.inputs[1] not in args.inner_inputs)):
valid_inputs = True
idx_matrix_input = 0
......@@ -778,11 +777,10 @@ class PushOutScanOutput(gof.Optimizer):
outer_dot_output = theano.tensor.dot(*outer_dot_inputs)
# Modify the outer graph to add the outer Dot
fgraph.replace_all([
(new_scan_args.outer_out_nit_sot[
dot_out_nitsot_idx],
outer_dot_output)],
reason="scanOp_pushout_output")
fgraph.replace_all(
[(new_scan_args.outer_out_nit_sot[dot_out_nitsot_idx],
outer_dot_output)],
reason="scanOp_pushout_output")
break
......@@ -807,8 +805,9 @@ class PushOutScanOutput(gof.Optimizer):
sitsot_in_idx = nd.inputs.index(args.inner_in_sit_sot[
sitsot_idx])
dot_in_idx = 1 - sitsot_in_idx # 0 if sitsot_in_idx==1,
# 1 if sitsot_in_idx==0
# 0 if sitsot_in_idx==1, 1 if sitsot_in_idx==0
dot_in_idx = 1 - sitsot_in_idx
dot_input = nd.inputs[dot_in_idx]
if (dot_input.owner is not None and
......@@ -816,10 +815,8 @@ class PushOutScanOutput(gof.Optimizer):
len(dot_input.clients) == 1 and
dot_input.owner.inputs[0].ndim == 2 and
dot_input.owner.inputs[1].ndim == 2 and
self.get_outer_ndim(dot_input.owner.inputs[0], args) \
== 3 and
self.get_outer_ndim(dot_input.owner.inputs[1], args) \
== 3):
self.get_outer_ndim(dot_input.owner.inputs[0], args) == 3 and
self.get_outer_ndim(dot_input.owner.inputs[1], args) == 3):
# The optimization can be applied in this case.
......@@ -829,36 +826,32 @@ class PushOutScanOutput(gof.Optimizer):
(outer_dot_inputs,
new_scan_node,
new_scan_args) = \
self.push_out_inner_vars(fgraph,
inner_dot_inputs,
node, args)
self.push_out_inner_vars(fgraph, inner_dot_inputs,
node, args)
# Collapse some of the dimensions of the tensors
# so that they become matrices. This is because a
# dot is usually faster on two large matrices than
# a bunch of small ones
outer_dot_inputs[0] = theano.tensor.flatten(
outer_dot_inputs[0].dimshuffle(1, 0, 2),
outdim=2)
outer_dot_inputs[0].dimshuffle(1, 0, 2), outdim=2)
shape_input1 = theano.tensor.shape(outer_dot_inputs[1])
outer_dot_inputs[1] =\
outer_dot_inputs[1].reshape((shape_input1[0] *
shape_input1[1],
shape_input1[2]))
outer_dot_inputs[1].reshape((shape_input1[0] *
shape_input1[1],
shape_input1[2]))
# Perform the dot on the newly obtained matrices and
# add the initial value
outer_dot_output = theano.tensor.dot(*outer_dot_inputs)
init_value = \
new_scan_args.outer_in_sit_sot[sitsot_idx][0]
init_value = new_scan_args.outer_in_sit_sot[sitsot_idx][0]
replacement = outer_dot_output + init_value
# Alter the outer graph to use the output of the
# external Dot instead of the output of scan
# Modify the outer graph to add the outer Dot
outer_sitsot = \
new_scan_args.outer_out_sit_sot[sitsot_idx]
outer_sitsot = new_scan_args.outer_out_sit_sot[sitsot_idx]
subtensor_node = outer_sitsot.clients[0][0]
outer_sitsot_last_step = subtensor_node.outputs[0]
......@@ -883,12 +876,12 @@ class PushOutScanOutput(gof.Optimizer):
if len(outer_var.clients) == 1:
client = outer_var.clients[0][0]
if (client != 'output' and
isinstance(client.op, theano.tensor.Subtensor)):
if (client != 'output' and isinstance(client.op,
theano.tensor.Subtensor)):
lst = theano.tensor.subtensor.get_idx_list(
client.inputs, client.op.idx_list)
if (len(lst) == 1 and
theano.tensor.extract_constant(lst[0]) == -1):
theano.tensor.extract_constant(lst[0]) == -1):
return True
return False
......@@ -898,7 +891,7 @@ class PushOutScanOutput(gof.Optimizer):
# Given a variable, determine the number of dimension it would have if
# it was pushed out of scan
if (var in scan_args.inner_in_non_seqs or
isinstance(var, theano.Constant)):
isinstance(var, theano.Constant)):
outer_ndim = var.ndim
else:
......@@ -990,8 +983,8 @@ class PushOutScanOutput(gof.Optimizer):
len(old_scan_args.outer_out_shared))
new_node_old_outputs = (
new_scan_node.outputs[:new_node_new_outputs_idx] +
new_scan_node.outputs[new_node_new_outputs_idx+nb_new_outs:])
new_scan_node.outputs[:new_node_new_outputs_idx] +
new_scan_node.outputs[new_node_new_outputs_idx + nb_new_outs:])
fgraph.replace_all_validate_remove(
list(zip(old_scan_node.outputs, new_node_old_outputs)),
......@@ -1017,7 +1010,7 @@ class ScanInplaceOptimizer(Optimizer):
fgraph.attach_feature(toolbox.ReplaceValidate())
fgraph.attach_feature(DestroyHandler())
def attempt_scan_inplace(self, fgraph, node, output_indices):
def attempt_scan_inplace(self, fgraph, node, output_indices, alloc_ops):
"""Attempts to replace a Scan node by one which computes the specified
outputs inplace.
......@@ -1029,6 +1022,10 @@ class ScanInplaceOptimizer(Optimizer):
Scan node to replace by an inplace version
output_indices : list of integers
Indices of the outputs to attempt to compute inplace
alloc_ops : list of Op classes
Classes that represent operations that allocate new memory and
that the optimization should duplicate so it can operate inplace
on them.
"""
op = node.op
......@@ -1049,6 +1046,14 @@ class ScanInplaceOptimizer(Optimizer):
ls_end += op.outer_nitsot(node.inputs)
ls_end += op.outer_non_seqs(node.inputs)
# In `ls`, duplicate any input which has more than one client and is
# the output of an eligible allocation op
for i in range(len(ls)):
inp = ls[i]
if (len(inp.clients) > 1 and inp.owner and
isinstance(inp.owner.op, alloc_ops)):
ls[i] = inp.owner.op(*inp.owner.inputs)
n_outs = len(ls)
for idx in xrange(n_outs):
if ls[idx] in ls[:idx]:
......@@ -1079,6 +1084,21 @@ class ScanInplaceOptimizer(Optimizer):
def apply(self, fgraph):
# Depending on the values of gpu_flag and gpua_flag, get the list of
# memory allocation ops that the optimization should be able to handle
alloc_ops = (Alloc, AllocEmpty)
if self.gpu_flag:
alloc_ops += (theano.sandbox.cuda.GpuAlloc,
theano.sandbox.cuda.GpuAllocEmpty)
if self.gpua_flag:
# gpuarray might be imported but not its GpuAlloc and
# GpuAllocEmpty ops.
try:
alloc_ops += (theano.sandbox.gpuarray.GpuAlloc,
theano.sandbox.gpuarray.GpuAllocEmpty)
except:
pass
nodes = fgraph.toposort()[::-1]
scan_nodes = [x for x in nodes
if (isinstance(x.op, scan_op.Scan) and
......@@ -1101,7 +1121,18 @@ class ScanInplaceOptimizer(Optimizer):
out_indices = []
for out_idx in range(n_outs):
inp_idx = 1 + op.n_seqs + out_idx
inp = original_node.inputs[inp_idx]
# If the input is from an eligible allocation node, attempt to
# be inplace on it, even if other nodes are modifying it
# inplace.
if inp.owner and isinstance(inp.owner.op, alloc_ops):
out_indices.append(out_idx)
continue
# If the input is not from an eligible allocation node, only
# attempt to be inplace on it if nothing else is currently
# inplace on it.
input_used_inplace = False
for c in original_node.inputs[inp_idx].clients:
client = c[0]
......@@ -1122,14 +1153,15 @@ class ScanInplaceOptimizer(Optimizer):
out_indices.append(out_idx)
node = self.attempt_scan_inplace(fgraph, scan_nodes[scan_idx],
out_indices)
out_indices, alloc_ops)
if node is original_node:
# Making the scan compute all plausible recurrent outputs
# inplace has failed. Attempt all plausible recurrent output
# individually.
for pos in out_indices:
node = self.attempt_scan_inplace(fgraph, node, [pos])
node = self.attempt_scan_inplace(fgraph, node, [pos],
alloc_ops)
class ScanSaveMem(gof.Optimizer):
......@@ -1242,7 +1274,7 @@ class ScanSaveMem(gof.Optimizer):
for cl, _ in out.clients:
# 2.1 outputs of the function
#=> output needs all its intermediate values
# => output needs all its intermediate values
if type(cl) == str:
# if the node is actually an output, then
# we need to store the entire thing
......@@ -1250,20 +1282,20 @@ class ScanSaveMem(gof.Optimizer):
slices[i] = None
break
# 2.2 non-subtensor nodes
#=> output needs all its intermediate values
# => output needs all its intermediate values
elif not isinstance(cl.op, tensor.Subtensor):
global_nsteps = None
slices[i] = None
break
# 2.3 subtensor nodes
#=> output might need to store just a subset of its values
# => output might need to store just a subset of its values
else:
# 2.3.1 extract idx list of subtensor
this_slice = tensor.get_idx_list(cl.inputs,
cl.op.idx_list)
if this_slice is None:
# if unable to extract idx_list
#=> outputs needs all its intermediate values
# => outputs needs all its intermediate values
global_nsteps = None
slices[i] = None
break
......@@ -1284,7 +1316,7 @@ class ScanSaveMem(gof.Optimizer):
slices[i] += [(cf_slice, this_slice)]
if (isinstance(this_slice[0], slice) and
this_slice[0].stop is None):
this_slice[0].stop is None):
global_nsteps = None
if isinstance(cf_slice[0], slice):
stop = tensor.basic.extract_constant(cf_slice[0].stop)
......@@ -1374,7 +1406,7 @@ class ScanSaveMem(gof.Optimizer):
break
if (isinstance(this_slice[0], slice) and
this_slice[0].start is None):
this_slice[0].start is None):
store_steps[i] = 0
break
......@@ -1406,8 +1438,7 @@ class ScanSaveMem(gof.Optimizer):
# for mitsots and sitsots (because mitmots are not
# currently supported by the mechanism) and only if
# the pre-allocation mechanism is activated.
prealloc_outs = \
theano.config.scan.allow_output_prealloc
prealloc_outs = theano.config.scan.allow_output_prealloc
first_mitsot_idx = node.op.n_mit_mot
last_sitsot_idx = (node.op.n_mit_mot +
......@@ -1433,7 +1464,7 @@ class ScanSaveMem(gof.Optimizer):
# currently.
# pval = pre_greedy_local_optimizer(list_opt_slice,
# pval)
#pval = pre_constant_merge([pval])[0]
# pval = pre_constant_merge([pval])[0]
# if (isinstance(pval, theano.tensor.TensorConstant)
# and
# pval.dtype.startswith('int')):
......@@ -1554,7 +1585,7 @@ class ScanSaveMem(gof.Optimizer):
nw_steps)
nw_inputs[in_idx] = nw_input
else:
nw_input = nw_inputs[in_idx][:(initl+nw_steps)]
nw_input = nw_inputs[in_idx][:(initl + nw_steps)]
elif idx < op.n_mit_sot + op.n_sit_sot + op.n_nit_sot:
in_idx = offset + idx + op.n_shared_outs
......@@ -1563,7 +1594,7 @@ class ScanSaveMem(gof.Optimizer):
# 3.5 Remove unwanted orphan outputs
(inps, outs, info, node_ins, compress_map) = \
scan_utils.compress_outs(op, not_required, nw_inputs)
scan_utils.compress_outs(op, not_required, nw_inputs)
inv_compress_map = OrderedDict()
for k, v in iteritems(compress_map):
inv_compress_map[v] = k
......@@ -1633,15 +1664,15 @@ class ScanSaveMem(gof.Optimizer):
start = (cnf_slice[0].start - nw_steps -
init_l[pos] + store_steps[pos])
if (cnf_slice[0].stop is not None and
cnf_slice[0].stop != maxsize):
cnf_slice[0].stop != maxsize):
stop = (cnf_slice[0].stop - nw_steps -
init_l[pos] + store_steps[pos])
else:
stop = None
nw_slice = ((slice(sanitize(start),
sanitize(stop),
sanitize(cnf_slice[0].step)),)
+ tuple(old_slices[1:]))
sanitize(cnf_slice[0].step)),) +
tuple(old_slices[1:]))
else:
position = (cnf_slice[0] - nw_steps -
......@@ -1662,8 +1693,7 @@ class ScanSaveMem(gof.Optimizer):
# 3.9. Get replace pairs for all other nodes
if flag_store or global_nsteps is not None:
for idx, o in enumerate(node.outputs):
if not (idx in replaced_outs) and \
not idx in not_required:
if not (idx in replaced_outs) and idx not in not_required:
nw_pos = compress_map[idx]
old_new += [(o, new_outs[nw_pos])]
# Check if the new outputs depend on the old scan node
......@@ -1808,7 +1838,7 @@ class ScanMerge(gof.Optimizer):
flat_inner_outs = sum(inner_outs[idx], [])
# clone
flat_inner_ins, flat_inner_outs = scan_utils.reconstruct_graph(
flat_inner_ins, flat_inner_outs)
flat_inner_ins, flat_inner_outs)
# split the new inner variables again in seq, mitmot, etc.
new_inner_ins = []
count = 0
......@@ -2072,8 +2102,8 @@ def scan_merge_inouts(node):
# because they could have different sizes, and the corresponding
# outer outputs cannot be merged in that case.
for s_outer_i, s_inner_o, s_outer_o in seen:
if (equal_computations([inner_o], [s_inner_o], left, right)
and outer_i == s_outer_i):
if (equal_computations([inner_o], [s_inner_o], left, right) and
outer_i == s_outer_i):
return s_outer_o
seen.append((outer_i, inner_o, outer_o))
return outer_o
......@@ -2116,9 +2146,10 @@ def scan_merge_inouts(node):
na.outer_out_mit_mot,
na.mit_mot_out_slices):
for s_outer_imm, s_inner_omm, s_outer_omm, sosl in seen:
if (osl == sosl
and equal_computations(inner_omm, s_inner_omm, left, right)
and outer_imm == s_outer_imm):
if (osl == sosl and
equal_computations(inner_omm, s_inner_omm, left, right) and
outer_imm == s_outer_imm):
new_outer_out_mit_mot.append(s_outer_omm)
break
else:
......@@ -2168,17 +2199,15 @@ class PushOutDot1(gof.Optimizer):
inp in out.owner.inputs and
len(outer_out.clients) == 1 and
not isinstance(outer_out.clients[0][0], str) and
isinstance(outer_out.clients[0][0].op, theano.tensor.Subtensor)
and outer_out.clients[0][0].op.idx_list == (-1,)):
isinstance(outer_out.clients[0][0].op, theano.tensor.Subtensor) and
outer_out.clients[0][0].op.idx_list == (-1,)):
x = out.owner.inputs[0]
if x == inp:
x = out.owner.inputs[1]
# We need to check if x is the result of an outer product
if (x.owner and
isinstance(x.owner.op, theano.tensor.Dot) and
x.owner.inputs[0].ndim == 2 and
x.owner.inputs[1].ndim == 2):
if (x.owner and isinstance(x.owner.op, theano.tensor.Dot) and
x.owner.inputs[0].ndim == 2 and x.owner.inputs[1].ndim == 2):
# We need to check if any of the inputs are a sequence
inp1 = x.owner.inputs[0]
......@@ -2219,18 +2248,17 @@ class PushOutDot1(gof.Optimizer):
new_info = op.info.copy()
st = len(op.mitmot_taps()) + len(op.mitsot_taps())
new_info['tap_array'] = (\
new_info['tap_array'][:st + idx] +
new_info['tap_array'][st +
idx + 1:])
new_info['tap_array'] = (
new_info['tap_array'][:st + idx] +
new_info['tap_array'][st + idx + 1:])
new_info['n_sit_sot'] -= 1
new_info['n_nit_sot'] += 1
inner_sitsot = inner_sitsot[:idx] + \
inner_sitsot[idx + 1:]
outer_sitsot = outer_sitsot[:idx] + \
outer_sitsot[idx + 1:]
inner_sitsot_outs = inner_sitsot_outs[:idx] +\
inner_sitsot_outs[idx + 1:]
inner_sitsot = (inner_sitsot[:idx] +
inner_sitsot[idx + 1:])
outer_sitsot = (outer_sitsot[:idx] +
outer_sitsot[idx + 1:])
inner_sitsot_outs = (inner_sitsot_outs[:idx] +
inner_sitsot_outs[idx + 1:])
# add n_steps as the length
inner_nitsot_outs.append(new_scan_out)
......@@ -2246,8 +2274,8 @@ class PushOutDot1(gof.Optimizer):
inner_nitsot_outs +
inner_shared_outs)
new_inner_inps, new_inner_outs =\
scan_utils.reconstruct_graph(
_new_inner_inps, _new_inner_outs)
scan_utils.reconstruct_graph(_new_inner_inps,
_new_inner_outs)
new_op = scan_op.Scan(new_inner_inps, new_inner_outs,
new_info)
_scan_inputs = ([node.inputs[0]] +
......@@ -2267,11 +2295,7 @@ class PushOutDot1(gof.Optimizer):
# We need now to pair correctly the new outputs
# with the old ones
outer_mitmot_outs = new_op.outer_mitmot_outs(new_outs)
outer_mitsot_outs = new_op.outer_mitsot_outs(new_outs)
outer_sitsot_outs = new_op.outer_sitsot_outs(new_outs)
outer_nitsot_outs = new_op.outer_nitsot_outs(new_outs)
outer_shared_outs = new_op.outer_shared_outs(new_outs)
_val = outer_nitsot_outs[-1]
outer_nitsot_outs = outer_nitsot_outs[:-1]
......@@ -2305,7 +2329,7 @@ class PushOutDot1(gof.Optimizer):
old_new = list(zip(node.outputs[:pos], new_outs[:pos]))
old = node.outputs[pos].clients[0][0].outputs[0]
old_new.append((old, new_out))
old_new += list(zip(node.outputs[pos+1:],
old_new += list(zip(node.outputs[pos + 1:],
new_outs[pos:]))
fgraph.replace_all_validate_remove(
old_new, remove=[node], reason='scan_pushout_dot1')
......@@ -2374,61 +2398,61 @@ scan_seqopt1.register('scanOp_pushout_output',
scan_eqopt2.register('constant_folding_for_scan2',
opt.in2out(tensor.opt.constant_folding,
ignore_newtrees=True),
1,
'fast_run',
'scan')
opt.in2out(tensor.opt.constant_folding,
ignore_newtrees=True),
1,
'fast_run',
'scan')
scan_eqopt2.register('scanOp_remove_constants_and_unused_inputs1',
opt.in2out(remove_constants_and_unused_inputs_scan,
ignore_newtrees=True),
2,
'remove_constants_and_unused_inputs_scan',
'fast_run',
'scan')
opt.in2out(remove_constants_and_unused_inputs_scan,
ignore_newtrees=True),
2,
'remove_constants_and_unused_inputs_scan',
'fast_run',
'scan')
# after const merge but before stabilize so that we can have identity
# for equivalent nodes but we still have the chance to hoist stuff out
# of the scan later.
scan_eqopt2.register('scanOp_merge',
ScanMerge(),
4,
'fast_run',
'scan')
ScanMerge(),
4,
'fast_run',
'scan')
# After Merge optimization
scan_eqopt2.register('scanop_remove_constants_and_unused_inputs2',
opt.in2out(remove_constants_and_unused_inputs_scan,
ignore_newtrees=True),
5,
'remove_constants_and_unused_inputs_scan',
'fast_run',
'scan')
opt.in2out(remove_constants_and_unused_inputs_scan,
ignore_newtrees=True),
5,
'remove_constants_and_unused_inputs_scan',
'fast_run',
'scan')
scan_eqopt2.register('scanOp_merge_inouts',
opt.in2out(scan_merge_inouts, ignore_newtrees=True),
6,
'scan_merge_inouts',
'fast_run',
'scan')
opt.in2out(scan_merge_inouts, ignore_newtrees=True),
6,
'scan_merge_inouts',
'fast_run',
'scan')
# Just before specialize to have the other optimization
# like constant folding being applied
# This doesn't introduce inplace operations.
scan_eqopt2.register('scanOp_save_mem',
ScanSaveMem(),
7,
'fast_run',
'scan')
ScanSaveMem(),
7,
'fast_run',
'scan')
# After everything else
scan_eqopt2.register('scanOp_remove_constants_and_unused_inputs3',
opt.in2out(remove_constants_and_unused_inputs_scan,
ignore_newtrees=True),
8,
'remove_constants_and_unused_inputs_scan',
'fast_run',
'scan')
opt.in2out(remove_constants_and_unused_inputs_scan,
ignore_newtrees=True),
8,
'remove_constants_and_unused_inputs_scan',
'fast_run',
'scan')
......@@ -164,7 +164,6 @@ whitelist_flake8 = [
"scan_module/scan_op.py",
"scan_module/scan_perform_ext.py",
"scan_module/__init__.py",
"scan_module/scan_opt.py",
"scan_module/tests/test_scan.py",
"scan_module/tests/test_scan_opt.py",
"misc/tests/test_may_share_memory.py",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论