提交 c8dc3dbe authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #3524 from carriepl/scan_inplace_opt

Scan inplace opt
...@@ -316,7 +316,7 @@ if cuda_available: ...@@ -316,7 +316,7 @@ if cuda_available:
GpuDimShuffle, GpuCAReduce, GpuReshape, GpuContiguous, GpuDimShuffle, GpuCAReduce, GpuReshape, GpuContiguous,
GpuSubtensor, GpuIncSubtensor, GpuSubtensor, GpuIncSubtensor,
GpuAdvancedSubtensor1, GpuAdvancedIncSubtensor1, GpuAdvancedSubtensor1, GpuAdvancedIncSubtensor1,
GpuFlatten, GpuShape, GpuAlloc, GpuSplit, GpuFlatten, GpuShape, GpuAlloc, GpuAllocEmpty, GpuSplit,
GpuJoin, fscalar, fvector, fmatrix, frow, fcol, GpuJoin, fscalar, fvector, fmatrix, frow, fcol,
ftensor3, ftensor4, ftensor3, ftensor4,
scalar, vector, matrix, row, col, scalar, vector, matrix, row, col,
......
...@@ -68,8 +68,9 @@ if pygpu: ...@@ -68,8 +68,9 @@ if pygpu:
theano.compile.shared_constructor(gpuarray_shared_constructor) theano.compile.shared_constructor(gpuarray_shared_constructor)
optdb.add_tags('gpuarray_opt', 'fast_run', 'fast_compile') optdb.add_tags('gpuarray_opt', 'fast_run', 'fast_compile')
from .basic_ops import (GpuAlloc, GpuContiguous, GpuEye, GpuFromHost, from .basic_ops import (GpuAlloc, GpuAllocEmpty, GpuContiguous, GpuEye,
GpuJoin, GpuReshape, GpuSplit, HostFromGpu) GpuFromHost, GpuJoin, GpuReshape, GpuSplit,
HostFromGpu)
from .basic_ops import host_from_gpu, GpuFromHost from .basic_ops import host_from_gpu, GpuFromHost
from .elemwise import GpuElemwise from .elemwise import GpuElemwise
from .subtensor import (GpuSubtensor, GpuIncSubtensor, from .subtensor import (GpuSubtensor, GpuIncSubtensor,
......
...@@ -48,16 +48,7 @@ scan_eqopt2 -> They are all global optimizer. (in2out convert local to global). ...@@ -48,16 +48,7 @@ scan_eqopt2 -> They are all global optimizer. (in2out convert local to global).
in2out(scan_merge_inouts), in2out(scan_merge_inouts),
ScanSaveMem, ScanSaveMem,
in2out(remove_constants_and_unused_inputs_scan3) in2out(remove_constants_and_unused_inputs_scan3)
""" """
__docformat__ = 'restructedtext en'
__authors__ = ("Razvan Pascanu "
"Frederic Bastien "
"James Bergstra "
"Pascal Lamblin "
"Arnaud Bergeron ")
__copyright__ = "(c) 2010, Universite de Montreal"
__contact__ = "Razvan Pascanu <r.pascanu@gmail>"
import logging import logging
import copy import copy
...@@ -66,21 +57,29 @@ import numpy ...@@ -66,21 +57,29 @@ import numpy
import theano import theano
from theano import tensor from theano import tensor
from theano.tensor import opt, get_scalar_constant_value from theano.tensor import opt, get_scalar_constant_value, Alloc, AllocEmpty
from theano import gof from theano import gof
from theano.compat import OrderedDict from theano.compat import OrderedDict
from six import integer_types, iteritems from six import integer_types, iteritems
from six.moves import xrange from six.moves import xrange
from theano.gof.opt import Optimizer
from theano.gof.opt import pre_constant_merge, pre_greedy_local_optimizer
from theano.gof import toolbox, DestroyHandler, InconsistencyError
from theano.compile import optdb from theano.compile import optdb
from theano.compile.function_module import deep_copy_op from theano.compile.function_module import deep_copy_op
from theano.gof import toolbox, DestroyHandler, InconsistencyError
from theano.gof.opt import Optimizer
from theano.gof.opt import pre_constant_merge, pre_greedy_local_optimizer
from theano.scan_module import scan_op from theano.scan_module import scan_op
from theano.scan_module import scan_utils from theano.scan_module import scan_utils
from theano.scan_module.scan_utils import equal_computations, find_up, \ from theano.scan_module.scan_utils import equal_computations, find_up, scan_args
scan_args
__docformat__ = 'restructedtext en'
__authors__ = ("Razvan Pascanu "
"Frederic Bastien "
"James Bergstra "
"Pascal Lamblin "
"Arnaud Bergeron ")
__copyright__ = "(c) 2010, Universite de Montreal"
__contact__ = "Razvan Pascanu <r.pascanu@gmail>"
# Logging function for sending warning or info # Logging function for sending warning or info
...@@ -243,18 +242,17 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -243,18 +242,17 @@ class PushOutNonSeqScan(gof.Optimizer):
local_fgraph_topo = theano.gof.graph.io_toposort(clean_inputs, local_fgraph_topo = theano.gof.graph.io_toposort(clean_inputs,
clean_outputs) clean_outputs)
local_fgraph_outs_set = set(clean_outputs) local_fgraph_outs_set = set(clean_outputs)
local_fgraph_outs_map = dict([(v, k) for k, v in \ local_fgraph_outs_map = dict([(v, k) for k, v in
enumerate(clean_outputs)]) enumerate(clean_outputs)])
to_remove_set = set() to_remove_set = set()
to_replace_set = set() to_replace_set = set()
to_replace_map = OrderedDict() to_replace_map = OrderedDict()
nto_replace = 0
def add_to_replace(y): def add_to_replace(y):
to_replace_set.add(y) to_replace_set.add(y)
to_replace_map[y] = add_to_replace.n to_replace_map[y] = add_to_replace.n
add_to_replace.n +=1 add_to_replace.n += 1
add_to_replace.n = 0 add_to_replace.n = 0
replace_with_in = [] replace_with_in = []
...@@ -264,7 +262,8 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -264,7 +262,8 @@ class PushOutNonSeqScan(gof.Optimizer):
# Construct the list of non_sequences to simplify a few things # Construct the list of non_sequences to simplify a few things
inner_non_seqs = op.inner_non_seqs(clean_inputs) inner_non_seqs = op.inner_non_seqs(clean_inputs)
inner_non_seqs_set = set(inner_non_seqs) inner_non_seqs_set = set(inner_non_seqs)
inner_non_seqs_map = dict([(v,k) for k,v in enumerate(inner_non_seqs)]) inner_non_seqs_map = dict([(v, k) for k, v in
enumerate(inner_non_seqs)])
outer_non_seqs = op.outer_non_seqs(node.inputs) outer_non_seqs = op.outer_non_seqs(node.inputs)
...@@ -275,7 +274,7 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -275,7 +274,7 @@ class PushOutNonSeqScan(gof.Optimizer):
assert len(inner_seqs) == len(outer_seqs) assert len(inner_seqs) == len(outer_seqs)
for nd in local_fgraph_topo: for nd in local_fgraph_topo:
if (# we haven't already looked at this node if ( # we haven't already looked at this node
nd not in to_remove_set and nd not in to_remove_set and
all([((x in inner_non_seqs_set) or all([((x in inner_non_seqs_set) or
(x.owner in to_remove_set) or (x.owner in to_remove_set) or
...@@ -337,7 +336,7 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -337,7 +336,7 @@ class PushOutNonSeqScan(gof.Optimizer):
to_keep_set.update(nd.inputs) to_keep_set.update(nd.inputs)
for out, idx in to_replace_map.items(): for out, idx in to_replace_map.items():
if (# If types are different, conversion Op will be inserted, if ( # If types are different, conversion Op will be inserted,
# and it may trigger an infinite loop. # and it may trigger an infinite loop.
replace_with_in[idx].type == out.type and replace_with_in[idx].type == out.type and
out in to_keep_set and out in to_keep_set and
...@@ -450,13 +449,12 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -450,13 +449,12 @@ class PushOutSeqScan(gof.Optimizer):
local_fgraph_topo = theano.gof.graph.io_toposort(clean_inputs, local_fgraph_topo = theano.gof.graph.io_toposort(clean_inputs,
clean_outputs) clean_outputs)
local_fgraph_outs_set = set(clean_outputs) local_fgraph_outs_set = set(clean_outputs)
local_fgraph_outs_map = dict([(v,k) for k,v in \ local_fgraph_outs_map = dict([(v, k) for k, v in
enumerate(clean_outputs)]) enumerate(clean_outputs)])
to_remove_set = set() to_remove_set = set()
to_replace_set = set() to_replace_set = set()
to_replace_map = OrderedDict() to_replace_map = OrderedDict()
nto_replace = 0
def add_to_replace(y): def add_to_replace(y):
to_replace_set.add(y) to_replace_set.add(y)
...@@ -471,12 +469,14 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -471,12 +469,14 @@ class PushOutSeqScan(gof.Optimizer):
# Construct the list of non_sequences to simplify a few things # Construct the list of non_sequences to simplify a few things
inner_non_seqs = op.inner_non_seqs(clean_inputs) inner_non_seqs = op.inner_non_seqs(clean_inputs)
inner_non_seqs_set = set(inner_non_seqs) inner_non_seqs_set = set(inner_non_seqs)
inner_non_seqs_map = dict([(v,k) for k,v in enumerate(inner_non_seqs)]) inner_non_seqs_map = dict([(v, k) for k, v in
enumerate(inner_non_seqs)])
outer_non_seqs = op.outer_non_seqs(node.inputs) outer_non_seqs = op.outer_non_seqs(node.inputs)
inner_seqs = op.inner_seqs(clean_inputs) inner_seqs = op.inner_seqs(clean_inputs)
inner_seqs_set = set(inner_seqs) inner_seqs_set = set(inner_seqs)
inner_seqs_map = dict([(v,k) for k,v in enumerate(inner_seqs)]) inner_seqs_map = dict([(v, k) for k, v in
enumerate(inner_seqs)])
outer_seqs = op.outer_seqs(node.inputs) outer_seqs = op.outer_seqs(node.inputs)
assert len(inner_non_seqs) == len(outer_non_seqs) assert len(inner_non_seqs) == len(outer_non_seqs)
...@@ -582,11 +582,10 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -582,11 +582,10 @@ class PushOutSeqScan(gof.Optimizer):
to_keep_set.update(nd.inputs) to_keep_set.update(nd.inputs)
for out, idx in to_replace_map.items(): for out, idx in to_replace_map.items():
if (out in to_keep_set if (out in to_keep_set and out.owner not in existent_nodes_set and
and out.owner not in existent_nodes_set
# If types are different, conversion Op will be inserted, # If types are different, conversion Op will be inserted,
# and it may trigger an infinite loop. # and it may trigger an infinite loop.
and replace_with_in[idx].type == out.type): replace_with_in[idx].type == out.type):
clean_to_replace.append(out) clean_to_replace.append(out)
clean_replace_with_in.append(replace_with_in[idx]) clean_replace_with_in.append(replace_with_in[idx])
...@@ -682,7 +681,7 @@ class PushOutScanOutput(gof.Optimizer): ...@@ -682,7 +681,7 @@ class PushOutScanOutput(gof.Optimizer):
not x.op.as_while)] not x.op.as_while)]
for node in nodelist: for node in nodelist:
# Process the node as long as something gets optimized # Process the node as long as something gets optimized
while node != None: while node is not None:
node = self.process_node(fgraph, node) node = self.process_node(fgraph, node)
def process_node(self, fgraph, node): def process_node(self, fgraph, node):
...@@ -778,9 +777,8 @@ class PushOutScanOutput(gof.Optimizer): ...@@ -778,9 +777,8 @@ class PushOutScanOutput(gof.Optimizer):
outer_dot_output = theano.tensor.dot(*outer_dot_inputs) outer_dot_output = theano.tensor.dot(*outer_dot_inputs)
# Modify the outer graph to add the outer Dot # Modify the outer graph to add the outer Dot
fgraph.replace_all([ fgraph.replace_all(
(new_scan_args.outer_out_nit_sot[ [(new_scan_args.outer_out_nit_sot[dot_out_nitsot_idx],
dot_out_nitsot_idx],
outer_dot_output)], outer_dot_output)],
reason="scanOp_pushout_output") reason="scanOp_pushout_output")
...@@ -807,8 +805,9 @@ class PushOutScanOutput(gof.Optimizer): ...@@ -807,8 +805,9 @@ class PushOutScanOutput(gof.Optimizer):
sitsot_in_idx = nd.inputs.index(args.inner_in_sit_sot[ sitsot_in_idx = nd.inputs.index(args.inner_in_sit_sot[
sitsot_idx]) sitsot_idx])
dot_in_idx = 1 - sitsot_in_idx # 0 if sitsot_in_idx==1, # 0 if sitsot_in_idx==1, 1 if sitsot_in_idx==0
# 1 if sitsot_in_idx==0 dot_in_idx = 1 - sitsot_in_idx
dot_input = nd.inputs[dot_in_idx] dot_input = nd.inputs[dot_in_idx]
if (dot_input.owner is not None and if (dot_input.owner is not None and
...@@ -816,10 +815,8 @@ class PushOutScanOutput(gof.Optimizer): ...@@ -816,10 +815,8 @@ class PushOutScanOutput(gof.Optimizer):
len(dot_input.clients) == 1 and len(dot_input.clients) == 1 and
dot_input.owner.inputs[0].ndim == 2 and dot_input.owner.inputs[0].ndim == 2 and
dot_input.owner.inputs[1].ndim == 2 and dot_input.owner.inputs[1].ndim == 2 and
self.get_outer_ndim(dot_input.owner.inputs[0], args) \ self.get_outer_ndim(dot_input.owner.inputs[0], args) == 3 and
== 3 and self.get_outer_ndim(dot_input.owner.inputs[1], args) == 3):
self.get_outer_ndim(dot_input.owner.inputs[1], args) \
== 3):
# The optimization can be be applied in this case. # The optimization can be be applied in this case.
...@@ -829,8 +826,7 @@ class PushOutScanOutput(gof.Optimizer): ...@@ -829,8 +826,7 @@ class PushOutScanOutput(gof.Optimizer):
(outer_dot_inputs, (outer_dot_inputs,
new_scan_node, new_scan_node,
new_scan_args) = \ new_scan_args) = \
self.push_out_inner_vars(fgraph, self.push_out_inner_vars(fgraph, inner_dot_inputs,
inner_dot_inputs,
node, args) node, args)
# Collapse some of the dimensions of the tensors # Collapse some of the dimensions of the tensors
...@@ -838,8 +834,7 @@ class PushOutScanOutput(gof.Optimizer): ...@@ -838,8 +834,7 @@ class PushOutScanOutput(gof.Optimizer):
# dot is usually faster on two large matrices than # dot is usually faster on two large matrices than
# a bunch of small ones # a bunch of small ones
outer_dot_inputs[0] = theano.tensor.flatten( outer_dot_inputs[0] = theano.tensor.flatten(
outer_dot_inputs[0].dimshuffle(1, 0, 2), outer_dot_inputs[0].dimshuffle(1, 0, 2), outdim=2)
outdim=2)
shape_input1 = theano.tensor.shape(outer_dot_inputs[1]) shape_input1 = theano.tensor.shape(outer_dot_inputs[1])
outer_dot_inputs[1] =\ outer_dot_inputs[1] =\
...@@ -850,15 +845,13 @@ class PushOutScanOutput(gof.Optimizer): ...@@ -850,15 +845,13 @@ class PushOutScanOutput(gof.Optimizer):
# Perform the dot on the newly obtained matrices and # Perform the dot on the newly obtained matrices and
# add the initial value # add the initial value
outer_dot_output = theano.tensor.dot(*outer_dot_inputs) outer_dot_output = theano.tensor.dot(*outer_dot_inputs)
init_value = \ init_value = new_scan_args.outer_in_sit_sot[sitsot_idx][0]
new_scan_args.outer_in_sit_sot[sitsot_idx][0]
replacement = outer_dot_output + init_value replacement = outer_dot_output + init_value
# Alter the outer graph to use the output of the # Alter the outer graph to use the output of the
# external Dot instead of the output of scan # external Dot instead of the output of scan
# Modify the outer graph to add the outer Dot # Modify the outer graph to add the outer Dot
outer_sitsot = \ outer_sitsot = new_scan_args.outer_out_sit_sot[sitsot_idx]
new_scan_args.outer_out_sit_sot[sitsot_idx]
subtensor_node = outer_sitsot.clients[0][0] subtensor_node = outer_sitsot.clients[0][0]
outer_sitsot_last_step = subtensor_node.outputs[0] outer_sitsot_last_step = subtensor_node.outputs[0]
...@@ -883,8 +876,8 @@ class PushOutScanOutput(gof.Optimizer): ...@@ -883,8 +876,8 @@ class PushOutScanOutput(gof.Optimizer):
if len(outer_var.clients) == 1: if len(outer_var.clients) == 1:
client = outer_var.clients[0][0] client = outer_var.clients[0][0]
if (client != 'output' and if (client != 'output' and isinstance(client.op,
isinstance(client.op, theano.tensor.Subtensor)): theano.tensor.Subtensor)):
lst = theano.tensor.subtensor.get_idx_list( lst = theano.tensor.subtensor.get_idx_list(
client.inputs, client.op.idx_list) client.inputs, client.op.idx_list)
if (len(lst) == 1 and if (len(lst) == 1 and
...@@ -991,7 +984,7 @@ class PushOutScanOutput(gof.Optimizer): ...@@ -991,7 +984,7 @@ class PushOutScanOutput(gof.Optimizer):
new_node_old_outputs = ( new_node_old_outputs = (
new_scan_node.outputs[:new_node_new_outputs_idx] + new_scan_node.outputs[:new_node_new_outputs_idx] +
new_scan_node.outputs[new_node_new_outputs_idx+nb_new_outs:]) new_scan_node.outputs[new_node_new_outputs_idx + nb_new_outs:])
fgraph.replace_all_validate_remove( fgraph.replace_all_validate_remove(
list(zip(old_scan_node.outputs, new_node_old_outputs)), list(zip(old_scan_node.outputs, new_node_old_outputs)),
...@@ -1017,7 +1010,7 @@ class ScanInplaceOptimizer(Optimizer): ...@@ -1017,7 +1010,7 @@ class ScanInplaceOptimizer(Optimizer):
fgraph.attach_feature(toolbox.ReplaceValidate()) fgraph.attach_feature(toolbox.ReplaceValidate())
fgraph.attach_feature(DestroyHandler()) fgraph.attach_feature(DestroyHandler())
def attempt_scan_inplace(self, fgraph, node, output_indices): def attempt_scan_inplace(self, fgraph, node, output_indices, alloc_ops):
"""Attempts to replace a Scan node by one which computes the specified """Attempts to replace a Scan node by one which computes the specified
outputs inplace. outputs inplace.
...@@ -1029,6 +1022,10 @@ class ScanInplaceOptimizer(Optimizer): ...@@ -1029,6 +1022,10 @@ class ScanInplaceOptimizer(Optimizer):
Scan node to replace by an inplace version Scan node to replace by an inplace version
output_indices : list of integers output_indices : list of integers
Indices of the outputs to attempt to compute inplace Indices of the outputs to attempt to compute inplace
alloc_ops : list of Op classes
Classes that represent operation that allocate new memory and
that the optimization should duplicate so it can operate inplace
on them.
""" """
op = node.op op = node.op
...@@ -1049,6 +1046,14 @@ class ScanInplaceOptimizer(Optimizer): ...@@ -1049,6 +1046,14 @@ class ScanInplaceOptimizer(Optimizer):
ls_end += op.outer_nitsot(node.inputs) ls_end += op.outer_nitsot(node.inputs)
ls_end += op.outer_non_seqs(node.inputs) ls_end += op.outer_non_seqs(node.inputs)
# In `ls`, duplicate any input which has more then one client and is
# the output of an eligible allocation op
for i in range(len(ls)):
inp = ls[i]
if (len(inp.clients) > 1 and inp.owner and
isinstance(inp.owner.op, alloc_ops)):
ls[i] = inp.owner.op(*inp.owner.inputs)
n_outs = len(ls) n_outs = len(ls)
for idx in xrange(n_outs): for idx in xrange(n_outs):
if ls[idx] in ls[:idx]: if ls[idx] in ls[:idx]:
...@@ -1079,6 +1084,21 @@ class ScanInplaceOptimizer(Optimizer): ...@@ -1079,6 +1084,21 @@ class ScanInplaceOptimizer(Optimizer):
def apply(self, fgraph): def apply(self, fgraph):
# Depending on the values of gpu_flag and gpua_flag, get the list of
# memory allocation ops that the optimization should be able to handle
alloc_ops = (Alloc, AllocEmpty)
if self.gpu_flag:
alloc_ops += (theano.sandbox.cuda.GpuAlloc,
theano.sandbox.cuda.GpuAllocEmpty)
if self.gpua_flag:
# gpuarray might be imported but not its GpuAlloc and
# GpuAllopEmpty ops.
try:
alloc_ops += (theano.sandbox.gpuarray.GpuAlloc,
theano.sandbox.gpuarray.GpuAllocEmpty)
except:
pass
nodes = fgraph.toposort()[::-1] nodes = fgraph.toposort()[::-1]
scan_nodes = [x for x in nodes scan_nodes = [x for x in nodes
if (isinstance(x.op, scan_op.Scan) and if (isinstance(x.op, scan_op.Scan) and
...@@ -1101,7 +1121,18 @@ class ScanInplaceOptimizer(Optimizer): ...@@ -1101,7 +1121,18 @@ class ScanInplaceOptimizer(Optimizer):
out_indices = [] out_indices = []
for out_idx in range(n_outs): for out_idx in range(n_outs):
inp_idx = 1 + op.n_seqs + out_idx inp_idx = 1 + op.n_seqs + out_idx
inp = original_node.inputs[inp_idx]
# If the input is from an eligible allocation node, attempt to
# be inplace on it, even if other nodes are modifying it
# inplace.
if inp.owner and isinstance(inp.owner.op, alloc_ops):
out_indices.append(out_idx)
continue
# If the input is not from an eligible allocation node, only
# attempt to be inplace on it if nothing else is currently
# inplace on it.
input_used_inplace = False input_used_inplace = False
for c in original_node.inputs[inp_idx].clients: for c in original_node.inputs[inp_idx].clients:
client = c[0] client = c[0]
...@@ -1122,14 +1153,15 @@ class ScanInplaceOptimizer(Optimizer): ...@@ -1122,14 +1153,15 @@ class ScanInplaceOptimizer(Optimizer):
out_indices.append(out_idx) out_indices.append(out_idx)
node = self.attempt_scan_inplace(fgraph, scan_nodes[scan_idx], node = self.attempt_scan_inplace(fgraph, scan_nodes[scan_idx],
out_indices) out_indices, alloc_ops)
if node is original_node: if node is original_node:
# Making the scan compute all plausible recurrent outputs # Making the scan compute all plausible recurrent outputs
# inplace has failed. Attempt all plausible recurrent output # inplace has failed. Attempt all plausible recurrent output
# individually. # individually.
for pos in out_indices: for pos in out_indices:
node = self.attempt_scan_inplace(fgraph, node, [pos]) node = self.attempt_scan_inplace(fgraph, node, [pos],
alloc_ops)
class ScanSaveMem(gof.Optimizer): class ScanSaveMem(gof.Optimizer):
...@@ -1242,7 +1274,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -1242,7 +1274,7 @@ class ScanSaveMem(gof.Optimizer):
for cl, _ in out.clients: for cl, _ in out.clients:
# 2.1 outputs of the function # 2.1 outputs of the function
#=> output needs all its intermediate values # => output needs all its intermediate values
if type(cl) == str: if type(cl) == str:
# if the node is actually an output, then # if the node is actually an output, then
# we need to store the entire thing # we need to store the entire thing
...@@ -1250,20 +1282,20 @@ class ScanSaveMem(gof.Optimizer): ...@@ -1250,20 +1282,20 @@ class ScanSaveMem(gof.Optimizer):
slices[i] = None slices[i] = None
break break
# 2.2 non-subtensor nodes # 2.2 non-subtensor nodes
#=> output needs all its intermediate values # => output needs all its intermediate values
elif not isinstance(cl.op, tensor.Subtensor): elif not isinstance(cl.op, tensor.Subtensor):
global_nsteps = None global_nsteps = None
slices[i] = None slices[i] = None
break break
# 2.3 subtensor nodes # 2.3 subtensor nodes
#=> output might need to store just a subset of its values # => output might need to store just a subset of its values
else: else:
# 2.3.1 extract idx list of subtensor # 2.3.1 extract idx list of subtensor
this_slice = tensor.get_idx_list(cl.inputs, this_slice = tensor.get_idx_list(cl.inputs,
cl.op.idx_list) cl.op.idx_list)
if this_slice is None: if this_slice is None:
# if unable to extract idx_list # if unable to extract idx_list
#=> outputs needs all its intermediate values # => outputs needs all its intermediate values
global_nsteps = None global_nsteps = None
slices[i] = None slices[i] = None
break break
...@@ -1406,8 +1438,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -1406,8 +1438,7 @@ class ScanSaveMem(gof.Optimizer):
# for mitsots and sitsots (because mitmots are not # for mitsots and sitsots (because mitmots are not
# currently supported by the mechanism) and only if # currently supported by the mechanism) and only if
# the pre-allocation mechanism is activated. # the pre-allocation mechanism is activated.
prealloc_outs = \ prealloc_outs = theano.config.scan.allow_output_prealloc
theano.config.scan.allow_output_prealloc
first_mitsot_idx = node.op.n_mit_mot first_mitsot_idx = node.op.n_mit_mot
last_sitsot_idx = (node.op.n_mit_mot + last_sitsot_idx = (node.op.n_mit_mot +
...@@ -1433,7 +1464,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -1433,7 +1464,7 @@ class ScanSaveMem(gof.Optimizer):
# currently. # currently.
# pval = pre_greedy_local_optimizer(list_opt_slice, # pval = pre_greedy_local_optimizer(list_opt_slice,
# pval) # pval)
#pval = pre_constant_merge([pval])[0] # pval = pre_constant_merge([pval])[0]
# if (isinstance(pval, theano.tensor.TensorConstant) # if (isinstance(pval, theano.tensor.TensorConstant)
# and # and
# pval.dtype.startswith('int')): # pval.dtype.startswith('int')):
...@@ -1554,7 +1585,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -1554,7 +1585,7 @@ class ScanSaveMem(gof.Optimizer):
nw_steps) nw_steps)
nw_inputs[in_idx] = nw_input nw_inputs[in_idx] = nw_input
else: else:
nw_input = nw_inputs[in_idx][:(initl+nw_steps)] nw_input = nw_inputs[in_idx][:(initl + nw_steps)]
elif idx < op.n_mit_sot + op.n_sit_sot + op.n_nit_sot: elif idx < op.n_mit_sot + op.n_sit_sot + op.n_nit_sot:
in_idx = offset + idx + op.n_shared_outs in_idx = offset + idx + op.n_shared_outs
...@@ -1640,8 +1671,8 @@ class ScanSaveMem(gof.Optimizer): ...@@ -1640,8 +1671,8 @@ class ScanSaveMem(gof.Optimizer):
stop = None stop = None
nw_slice = ((slice(sanitize(start), nw_slice = ((slice(sanitize(start),
sanitize(stop), sanitize(stop),
sanitize(cnf_slice[0].step)),) sanitize(cnf_slice[0].step)),) +
+ tuple(old_slices[1:])) tuple(old_slices[1:]))
else: else:
position = (cnf_slice[0] - nw_steps - position = (cnf_slice[0] - nw_steps -
...@@ -1662,8 +1693,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -1662,8 +1693,7 @@ class ScanSaveMem(gof.Optimizer):
# 3.9. Get replace pairs for all other nodes # 3.9. Get replace pairs for all other nodes
if flag_store or global_nsteps is not None: if flag_store or global_nsteps is not None:
for idx, o in enumerate(node.outputs): for idx, o in enumerate(node.outputs):
if not (idx in replaced_outs) and \ if not (idx in replaced_outs) and idx not in not_required:
not idx in not_required:
nw_pos = compress_map[idx] nw_pos = compress_map[idx]
old_new += [(o, new_outs[nw_pos])] old_new += [(o, new_outs[nw_pos])]
# Check if the new outputs depend on the old scan node # Check if the new outputs depend on the old scan node
...@@ -2072,8 +2102,8 @@ def scan_merge_inouts(node): ...@@ -2072,8 +2102,8 @@ def scan_merge_inouts(node):
# because they could have different sizes, and the corresponding # because they could have different sizes, and the corresponding
# outer outputs cannot be merged in that case. # outer outputs cannot be merged in that case.
for s_outer_i, s_inner_o, s_outer_o in seen: for s_outer_i, s_inner_o, s_outer_o in seen:
if (equal_computations([inner_o], [s_inner_o], left, right) if (equal_computations([inner_o], [s_inner_o], left, right) and
and outer_i == s_outer_i): outer_i == s_outer_i):
return s_outer_o return s_outer_o
seen.append((outer_i, inner_o, outer_o)) seen.append((outer_i, inner_o, outer_o))
return outer_o return outer_o
...@@ -2116,9 +2146,10 @@ def scan_merge_inouts(node): ...@@ -2116,9 +2146,10 @@ def scan_merge_inouts(node):
na.outer_out_mit_mot, na.outer_out_mit_mot,
na.mit_mot_out_slices): na.mit_mot_out_slices):
for s_outer_imm, s_inner_omm, s_outer_omm, sosl in seen: for s_outer_imm, s_inner_omm, s_outer_omm, sosl in seen:
if (osl == sosl if (osl == sosl and
and equal_computations(inner_omm, s_inner_omm, left, right) equal_computations(inner_omm, s_inner_omm, left, right) and
and outer_imm == s_outer_imm): outer_imm == s_outer_imm):
new_outer_out_mit_mot.append(s_outer_omm) new_outer_out_mit_mot.append(s_outer_omm)
break break
else: else:
...@@ -2168,17 +2199,15 @@ class PushOutDot1(gof.Optimizer): ...@@ -2168,17 +2199,15 @@ class PushOutDot1(gof.Optimizer):
inp in out.owner.inputs and inp in out.owner.inputs and
len(outer_out.clients) == 1 and len(outer_out.clients) == 1 and
not isinstance(outer_out.clients[0][0], str) and not isinstance(outer_out.clients[0][0], str) and
isinstance(outer_out.clients[0][0].op, theano.tensor.Subtensor) isinstance(outer_out.clients[0][0].op, theano.tensor.Subtensor) and
and outer_out.clients[0][0].op.idx_list == (-1,)): outer_out.clients[0][0].op.idx_list == (-1,)):
x = out.owner.inputs[0] x = out.owner.inputs[0]
if x == inp: if x == inp:
x = out.owner.inputs[1] x = out.owner.inputs[1]
# We need to check if x is the result of an outer product # We need to check if x is the result of an outer product
if (x.owner and if (x.owner and isinstance(x.owner.op, theano.tensor.Dot) and
isinstance(x.owner.op, theano.tensor.Dot) and x.owner.inputs[0].ndim == 2 and x.owner.inputs[1].ndim == 2):
x.owner.inputs[0].ndim == 2 and
x.owner.inputs[1].ndim == 2):
# We need to check if any of the inputs are a sequence # We need to check if any of the inputs are a sequence
inp1 = x.owner.inputs[0] inp1 = x.owner.inputs[0]
...@@ -2219,18 +2248,17 @@ class PushOutDot1(gof.Optimizer): ...@@ -2219,18 +2248,17 @@ class PushOutDot1(gof.Optimizer):
new_info = op.info.copy() new_info = op.info.copy()
st = len(op.mitmot_taps()) + len(op.mitsot_taps()) st = len(op.mitmot_taps()) + len(op.mitsot_taps())
new_info['tap_array'] = (\ new_info['tap_array'] = (
new_info['tap_array'][:st + idx] + new_info['tap_array'][:st + idx] +
new_info['tap_array'][st + new_info['tap_array'][st + idx + 1:])
idx + 1:])
new_info['n_sit_sot'] -= 1 new_info['n_sit_sot'] -= 1
new_info['n_nit_sot'] += 1 new_info['n_nit_sot'] += 1
inner_sitsot = inner_sitsot[:idx] + \ inner_sitsot = (inner_sitsot[:idx] +
inner_sitsot[idx + 1:] inner_sitsot[idx + 1:])
outer_sitsot = outer_sitsot[:idx] + \ outer_sitsot = (outer_sitsot[:idx] +
outer_sitsot[idx + 1:] outer_sitsot[idx + 1:])
inner_sitsot_outs = inner_sitsot_outs[:idx] +\ inner_sitsot_outs = (inner_sitsot_outs[:idx] +
inner_sitsot_outs[idx + 1:] inner_sitsot_outs[idx + 1:])
# add n_steps as the length # add n_steps as the length
inner_nitsot_outs.append(new_scan_out) inner_nitsot_outs.append(new_scan_out)
...@@ -2246,8 +2274,8 @@ class PushOutDot1(gof.Optimizer): ...@@ -2246,8 +2274,8 @@ class PushOutDot1(gof.Optimizer):
inner_nitsot_outs + inner_nitsot_outs +
inner_shared_outs) inner_shared_outs)
new_inner_inps, new_inner_outs =\ new_inner_inps, new_inner_outs =\
scan_utils.reconstruct_graph( scan_utils.reconstruct_graph(_new_inner_inps,
_new_inner_inps, _new_inner_outs) _new_inner_outs)
new_op = scan_op.Scan(new_inner_inps, new_inner_outs, new_op = scan_op.Scan(new_inner_inps, new_inner_outs,
new_info) new_info)
_scan_inputs = ([node.inputs[0]] + _scan_inputs = ([node.inputs[0]] +
...@@ -2267,11 +2295,7 @@ class PushOutDot1(gof.Optimizer): ...@@ -2267,11 +2295,7 @@ class PushOutDot1(gof.Optimizer):
# We need now to pair correctly the new outputs # We need now to pair correctly the new outputs
# with the old ones # with the old ones
outer_mitmot_outs = new_op.outer_mitmot_outs(new_outs)
outer_mitsot_outs = new_op.outer_mitsot_outs(new_outs)
outer_sitsot_outs = new_op.outer_sitsot_outs(new_outs)
outer_nitsot_outs = new_op.outer_nitsot_outs(new_outs) outer_nitsot_outs = new_op.outer_nitsot_outs(new_outs)
outer_shared_outs = new_op.outer_shared_outs(new_outs)
_val = outer_nitsot_outs[-1] _val = outer_nitsot_outs[-1]
outer_nitsot_outs = outer_nitsot_outs[:-1] outer_nitsot_outs = outer_nitsot_outs[:-1]
...@@ -2305,7 +2329,7 @@ class PushOutDot1(gof.Optimizer): ...@@ -2305,7 +2329,7 @@ class PushOutDot1(gof.Optimizer):
old_new = list(zip(node.outputs[:pos], new_outs[:pos])) old_new = list(zip(node.outputs[:pos], new_outs[:pos]))
old = node.outputs[pos].clients[0][0].outputs[0] old = node.outputs[pos].clients[0][0].outputs[0]
old_new.append((old, new_out)) old_new.append((old, new_out))
old_new += list(zip(node.outputs[pos+1:], old_new += list(zip(node.outputs[pos + 1:],
new_outs[pos:])) new_outs[pos:]))
fgraph.replace_all_validate_remove( fgraph.replace_all_validate_remove(
old_new, remove=[node], reason='scan_pushout_dot1') old_new, remove=[node], reason='scan_pushout_dot1')
......
...@@ -164,7 +164,6 @@ whitelist_flake8 = [ ...@@ -164,7 +164,6 @@ whitelist_flake8 = [
"scan_module/scan_op.py", "scan_module/scan_op.py",
"scan_module/scan_perform_ext.py", "scan_module/scan_perform_ext.py",
"scan_module/__init__.py", "scan_module/__init__.py",
"scan_module/scan_opt.py",
"scan_module/tests/test_scan.py", "scan_module/tests/test_scan.py",
"scan_module/tests/test_scan_opt.py", "scan_module/tests/test_scan_opt.py",
"misc/tests/test_may_share_memory.py", "misc/tests/test_may_share_memory.py",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论