提交 c8dc3dbe authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #3524 from carriepl/scan_inplace_opt

Scan inplace opt
...@@ -316,7 +316,7 @@ if cuda_available: ...@@ -316,7 +316,7 @@ if cuda_available:
GpuDimShuffle, GpuCAReduce, GpuReshape, GpuContiguous, GpuDimShuffle, GpuCAReduce, GpuReshape, GpuContiguous,
GpuSubtensor, GpuIncSubtensor, GpuSubtensor, GpuIncSubtensor,
GpuAdvancedSubtensor1, GpuAdvancedIncSubtensor1, GpuAdvancedSubtensor1, GpuAdvancedIncSubtensor1,
GpuFlatten, GpuShape, GpuAlloc, GpuSplit, GpuFlatten, GpuShape, GpuAlloc, GpuAllocEmpty, GpuSplit,
GpuJoin, fscalar, fvector, fmatrix, frow, fcol, GpuJoin, fscalar, fvector, fmatrix, frow, fcol,
ftensor3, ftensor4, ftensor3, ftensor4,
scalar, vector, matrix, row, col, scalar, vector, matrix, row, col,
...@@ -341,7 +341,7 @@ def use(device, ...@@ -341,7 +341,7 @@ def use(device,
Parameters Parameters
---------- ----------
device : string device : string
"cpu", "gpu", "gpuN" (N is the device number to use). "cpu", "gpu", "gpuN" (N is the device number to use).
force force
Will always raise an exception if we can't use the gpu. Will always raise an exception if we can't use the gpu.
......
...@@ -68,8 +68,9 @@ if pygpu: ...@@ -68,8 +68,9 @@ if pygpu:
theano.compile.shared_constructor(gpuarray_shared_constructor) theano.compile.shared_constructor(gpuarray_shared_constructor)
optdb.add_tags('gpuarray_opt', 'fast_run', 'fast_compile') optdb.add_tags('gpuarray_opt', 'fast_run', 'fast_compile')
from .basic_ops import (GpuAlloc, GpuContiguous, GpuEye, GpuFromHost, from .basic_ops import (GpuAlloc, GpuAllocEmpty, GpuContiguous, GpuEye,
GpuJoin, GpuReshape, GpuSplit, HostFromGpu) GpuFromHost, GpuJoin, GpuReshape, GpuSplit,
HostFromGpu)
from .basic_ops import host_from_gpu, GpuFromHost from .basic_ops import host_from_gpu, GpuFromHost
from .elemwise import GpuElemwise from .elemwise import GpuElemwise
from .subtensor import (GpuSubtensor, GpuIncSubtensor, from .subtensor import (GpuSubtensor, GpuIncSubtensor,
......
...@@ -48,16 +48,7 @@ scan_eqopt2 -> They are all global optimizer. (in2out convert local to global). ...@@ -48,16 +48,7 @@ scan_eqopt2 -> They are all global optimizer. (in2out convert local to global).
in2out(scan_merge_inouts), in2out(scan_merge_inouts),
ScanSaveMem, ScanSaveMem,
in2out(remove_constants_and_unused_inputs_scan3) in2out(remove_constants_and_unused_inputs_scan3)
""" """
__docformat__ = 'restructedtext en'
__authors__ = ("Razvan Pascanu "
"Frederic Bastien "
"James Bergstra "
"Pascal Lamblin "
"Arnaud Bergeron ")
__copyright__ = "(c) 2010, Universite de Montreal"
__contact__ = "Razvan Pascanu <r.pascanu@gmail>"
import logging import logging
import copy import copy
...@@ -66,21 +57,29 @@ import numpy ...@@ -66,21 +57,29 @@ import numpy
import theano import theano
from theano import tensor from theano import tensor
from theano.tensor import opt, get_scalar_constant_value from theano.tensor import opt, get_scalar_constant_value, Alloc, AllocEmpty
from theano import gof from theano import gof
from theano.compat import OrderedDict from theano.compat import OrderedDict
from six import integer_types, iteritems from six import integer_types, iteritems
from six.moves import xrange from six.moves import xrange
from theano.gof.opt import Optimizer
from theano.gof.opt import pre_constant_merge, pre_greedy_local_optimizer
from theano.gof import toolbox, DestroyHandler, InconsistencyError
from theano.compile import optdb from theano.compile import optdb
from theano.compile.function_module import deep_copy_op from theano.compile.function_module import deep_copy_op
from theano.gof import toolbox, DestroyHandler, InconsistencyError
from theano.gof.opt import Optimizer
from theano.gof.opt import pre_constant_merge, pre_greedy_local_optimizer
from theano.scan_module import scan_op from theano.scan_module import scan_op
from theano.scan_module import scan_utils from theano.scan_module import scan_utils
from theano.scan_module.scan_utils import equal_computations, find_up, \ from theano.scan_module.scan_utils import equal_computations, find_up, scan_args
scan_args
__docformat__ = 'restructedtext en'
__authors__ = ("Razvan Pascanu "
"Frederic Bastien "
"James Bergstra "
"Pascal Lamblin "
"Arnaud Bergeron ")
__copyright__ = "(c) 2010, Universite de Montreal"
__contact__ = "Razvan Pascanu <r.pascanu@gmail>"
# Logging function for sending warning or info # Logging function for sending warning or info
...@@ -151,7 +150,7 @@ def remove_constants_and_unused_inputs_scan(node): ...@@ -151,7 +150,7 @@ def remove_constants_and_unused_inputs_scan(node):
for idx in xrange(op.n_seqs): for idx in xrange(op.n_seqs):
node_inp = node.inputs[idx + 1] node_inp = node.inputs[idx + 1]
if (isinstance(node_inp, tensor.TensorConstant) and if (isinstance(node_inp, tensor.TensorConstant) and
node_inp.tag.unique_value is not None): node_inp.tag.unique_value is not None):
try: try:
# This works if input is a constant that has all entries # This works if input is a constant that has all entries
# equal # equal
...@@ -243,18 +242,17 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -243,18 +242,17 @@ class PushOutNonSeqScan(gof.Optimizer):
local_fgraph_topo = theano.gof.graph.io_toposort(clean_inputs, local_fgraph_topo = theano.gof.graph.io_toposort(clean_inputs,
clean_outputs) clean_outputs)
local_fgraph_outs_set = set(clean_outputs) local_fgraph_outs_set = set(clean_outputs)
local_fgraph_outs_map = dict([(v, k) for k, v in \ local_fgraph_outs_map = dict([(v, k) for k, v in
enumerate(clean_outputs)]) enumerate(clean_outputs)])
to_remove_set = set() to_remove_set = set()
to_replace_set = set() to_replace_set = set()
to_replace_map = OrderedDict() to_replace_map = OrderedDict()
nto_replace = 0
def add_to_replace(y): def add_to_replace(y):
to_replace_set.add(y) to_replace_set.add(y)
to_replace_map[y] = add_to_replace.n to_replace_map[y] = add_to_replace.n
add_to_replace.n +=1 add_to_replace.n += 1
add_to_replace.n = 0 add_to_replace.n = 0
replace_with_in = [] replace_with_in = []
...@@ -264,7 +262,8 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -264,7 +262,8 @@ class PushOutNonSeqScan(gof.Optimizer):
# Construct the list of non_sequences to simplify a few things # Construct the list of non_sequences to simplify a few things
inner_non_seqs = op.inner_non_seqs(clean_inputs) inner_non_seqs = op.inner_non_seqs(clean_inputs)
inner_non_seqs_set = set(inner_non_seqs) inner_non_seqs_set = set(inner_non_seqs)
inner_non_seqs_map = dict([(v,k) for k,v in enumerate(inner_non_seqs)]) inner_non_seqs_map = dict([(v, k) for k, v in
enumerate(inner_non_seqs)])
outer_non_seqs = op.outer_non_seqs(node.inputs) outer_non_seqs = op.outer_non_seqs(node.inputs)
...@@ -275,17 +274,17 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -275,17 +274,17 @@ class PushOutNonSeqScan(gof.Optimizer):
assert len(inner_seqs) == len(outer_seqs) assert len(inner_seqs) == len(outer_seqs)
for nd in local_fgraph_topo: for nd in local_fgraph_topo:
if (# we haven't already looked at this node if ( # we haven't already looked at this node
nd not in to_remove_set and nd not in to_remove_set and
all([((x in inner_non_seqs_set) or all([((x in inner_non_seqs_set) or
(x.owner in to_remove_set) or (x.owner in to_remove_set) or
isinstance(x, tensor.Constant)) isinstance(x, tensor.Constant))
for x in nd.inputs]) and for x in nd.inputs]) and
# we can do this because the assumption is that a # we can do this because the assumption is that a
# viewOp or deepCopyOp will be just at the end of the # viewOp or deepCopyOp will be just at the end of the
# function and not somewhere in the middle .. # function and not somewhere in the middle ..
not isinstance(nd.op, theano.compile.ViewOp) and not isinstance(nd.op, theano.compile.ViewOp) and
not isinstance(nd.op, theano.compile.DeepCopyOp)): not isinstance(nd.op, theano.compile.DeepCopyOp)):
# We have a candidate node to removable # We have a candidate node to removable
# Step 1. Reconstruct it on outside # Step 1. Reconstruct it on outside
...@@ -337,11 +336,11 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -337,11 +336,11 @@ class PushOutNonSeqScan(gof.Optimizer):
to_keep_set.update(nd.inputs) to_keep_set.update(nd.inputs)
for out, idx in to_replace_map.items(): for out, idx in to_replace_map.items():
if (# If types are different, conversion Op will be inserted, if ( # If types are different, conversion Op will be inserted,
# and it may trigger an infinite loop. # and it may trigger an infinite loop.
replace_with_in[idx].type == out.type and replace_with_in[idx].type == out.type and
out in to_keep_set and out in to_keep_set and
out.owner not in existent_nodes_set): out.owner not in existent_nodes_set):
clean_to_replace.append(out) clean_to_replace.append(out)
clean_replace_with_in.append(replace_with_in[idx]) clean_replace_with_in.append(replace_with_in[idx])
clean_replace_with_out.append(replace_with_out[idx]) clean_replace_with_out.append(replace_with_out[idx])
...@@ -450,13 +449,12 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -450,13 +449,12 @@ class PushOutSeqScan(gof.Optimizer):
local_fgraph_topo = theano.gof.graph.io_toposort(clean_inputs, local_fgraph_topo = theano.gof.graph.io_toposort(clean_inputs,
clean_outputs) clean_outputs)
local_fgraph_outs_set = set(clean_outputs) local_fgraph_outs_set = set(clean_outputs)
local_fgraph_outs_map = dict([(v,k) for k,v in \ local_fgraph_outs_map = dict([(v, k) for k, v in
enumerate(clean_outputs)]) enumerate(clean_outputs)])
to_remove_set = set() to_remove_set = set()
to_replace_set = set() to_replace_set = set()
to_replace_map = OrderedDict() to_replace_map = OrderedDict()
nto_replace = 0
def add_to_replace(y): def add_to_replace(y):
to_replace_set.add(y) to_replace_set.add(y)
...@@ -471,12 +469,14 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -471,12 +469,14 @@ class PushOutSeqScan(gof.Optimizer):
# Construct the list of non_sequences to simplify a few things # Construct the list of non_sequences to simplify a few things
inner_non_seqs = op.inner_non_seqs(clean_inputs) inner_non_seqs = op.inner_non_seqs(clean_inputs)
inner_non_seqs_set = set(inner_non_seqs) inner_non_seqs_set = set(inner_non_seqs)
inner_non_seqs_map = dict([(v,k) for k,v in enumerate(inner_non_seqs)]) inner_non_seqs_map = dict([(v, k) for k, v in
enumerate(inner_non_seqs)])
outer_non_seqs = op.outer_non_seqs(node.inputs) outer_non_seqs = op.outer_non_seqs(node.inputs)
inner_seqs = op.inner_seqs(clean_inputs) inner_seqs = op.inner_seqs(clean_inputs)
inner_seqs_set = set(inner_seqs) inner_seqs_set = set(inner_seqs)
inner_seqs_map = dict([(v,k) for k,v in enumerate(inner_seqs)]) inner_seqs_map = dict([(v, k) for k, v in
enumerate(inner_seqs)])
outer_seqs = op.outer_seqs(node.inputs) outer_seqs = op.outer_seqs(node.inputs)
assert len(inner_non_seqs) == len(outer_non_seqs) assert len(inner_non_seqs) == len(outer_non_seqs)
...@@ -488,7 +488,7 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -488,7 +488,7 @@ class PushOutSeqScan(gof.Optimizer):
(x.owner in to_remove_set) or (x.owner in to_remove_set) or
isinstance(x, tensor.Constant) or isinstance(x, tensor.Constant) or
(x in inner_seqs_set) for x in nd.inputs]) and (x in inner_seqs_set) for x in nd.inputs]) and
isinstance(nd.op, theano.tensor.Elemwise)): isinstance(nd.op, theano.tensor.Elemwise)):
outside_ins = [] outside_ins = []
depends_on_seqs = False depends_on_seqs = False
...@@ -538,7 +538,7 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -538,7 +538,7 @@ class PushOutSeqScan(gof.Optimizer):
elif (nd not in to_remove_set and elif (nd not in to_remove_set and
isinstance(nd.op, theano.tensor.DimShuffle) and isinstance(nd.op, theano.tensor.DimShuffle) and
(nd.inputs[0] in inner_seqs_set or (nd.inputs[0] in inner_seqs_set or
nd.inputs[0].owner in to_remove_set)): nd.inputs[0].owner in to_remove_set)):
to_remove_set.add(nd) to_remove_set.add(nd)
x = nd.inputs[0] x = nd.inputs[0]
...@@ -582,11 +582,10 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -582,11 +582,10 @@ class PushOutSeqScan(gof.Optimizer):
to_keep_set.update(nd.inputs) to_keep_set.update(nd.inputs)
for out, idx in to_replace_map.items(): for out, idx in to_replace_map.items():
if (out in to_keep_set if (out in to_keep_set and out.owner not in existent_nodes_set and
and out.owner not in existent_nodes_set # If types are different, conversion Op will be inserted,
# If types are different, conversion Op will be inserted, # and it may trigger an infinite loop.
# and it may trigger an infinite loop. replace_with_in[idx].type == out.type):
and replace_with_in[idx].type == out.type):
clean_to_replace.append(out) clean_to_replace.append(out)
clean_replace_with_in.append(replace_with_in[idx]) clean_replace_with_in.append(replace_with_in[idx])
...@@ -682,7 +681,7 @@ class PushOutScanOutput(gof.Optimizer): ...@@ -682,7 +681,7 @@ class PushOutScanOutput(gof.Optimizer):
not x.op.as_while)] not x.op.as_while)]
for node in nodelist: for node in nodelist:
# Process the node as long as something gets optimized # Process the node as long as something gets optimized
while node != None: while node is not None:
node = self.process_node(fgraph, node) node = self.process_node(fgraph, node)
def process_node(self, fgraph, node): def process_node(self, fgraph, node):
...@@ -702,7 +701,7 @@ class PushOutScanOutput(gof.Optimizer): ...@@ -702,7 +701,7 @@ class PushOutScanOutput(gof.Optimizer):
local_fgraph_topo = local_fgraph.toposort() local_fgraph_topo = local_fgraph.toposort()
for nd in local_fgraph_topo: for nd in local_fgraph_topo:
if (isinstance(nd.op, theano.tensor.Dot) and if (isinstance(nd.op, theano.tensor.Dot) and
nd.out in args.inner_out_nit_sot): nd.out in args.inner_out_nit_sot):
""" """
The following optimization involves pushing out, after the The following optimization involves pushing out, after the
scan, a Dot whose output is nitsot (not feed back to the inner scan, a Dot whose output is nitsot (not feed back to the inner
...@@ -737,8 +736,8 @@ class PushOutScanOutput(gof.Optimizer): ...@@ -737,8 +736,8 @@ class PushOutScanOutput(gof.Optimizer):
(nd.inputs[0] in args.inner_in_non_seqs or (nd.inputs[0] in args.inner_in_non_seqs or
isinstance(nd.inputs[0], tensor.Constant)) and isinstance(nd.inputs[0], tensor.Constant)) and
nd.inputs[1].ndim == 1 and nd.inputs[1].ndim == 1 and
(nd.inputs[1] in args.inner_in_seqs or (nd.inputs[1] in args.inner_in_seqs or
nd.inputs[1] not in args.inner_inputs)): nd.inputs[1] not in args.inner_inputs)):
valid_inputs = True valid_inputs = True
idx_matrix_input = 0 idx_matrix_input = 0
...@@ -778,11 +777,10 @@ class PushOutScanOutput(gof.Optimizer): ...@@ -778,11 +777,10 @@ class PushOutScanOutput(gof.Optimizer):
outer_dot_output = theano.tensor.dot(*outer_dot_inputs) outer_dot_output = theano.tensor.dot(*outer_dot_inputs)
# Modify the outer graph to add the outer Dot # Modify the outer graph to add the outer Dot
fgraph.replace_all([ fgraph.replace_all(
(new_scan_args.outer_out_nit_sot[ [(new_scan_args.outer_out_nit_sot[dot_out_nitsot_idx],
dot_out_nitsot_idx], outer_dot_output)],
outer_dot_output)], reason="scanOp_pushout_output")
reason="scanOp_pushout_output")
break break
...@@ -807,8 +805,9 @@ class PushOutScanOutput(gof.Optimizer): ...@@ -807,8 +805,9 @@ class PushOutScanOutput(gof.Optimizer):
sitsot_in_idx = nd.inputs.index(args.inner_in_sit_sot[ sitsot_in_idx = nd.inputs.index(args.inner_in_sit_sot[
sitsot_idx]) sitsot_idx])
dot_in_idx = 1 - sitsot_in_idx # 0 if sitsot_in_idx==1, # 0 if sitsot_in_idx==1, 1 if sitsot_in_idx==0
# 1 if sitsot_in_idx==0 dot_in_idx = 1 - sitsot_in_idx
dot_input = nd.inputs[dot_in_idx] dot_input = nd.inputs[dot_in_idx]
if (dot_input.owner is not None and if (dot_input.owner is not None and
...@@ -816,10 +815,8 @@ class PushOutScanOutput(gof.Optimizer): ...@@ -816,10 +815,8 @@ class PushOutScanOutput(gof.Optimizer):
len(dot_input.clients) == 1 and len(dot_input.clients) == 1 and
dot_input.owner.inputs[0].ndim == 2 and dot_input.owner.inputs[0].ndim == 2 and
dot_input.owner.inputs[1].ndim == 2 and dot_input.owner.inputs[1].ndim == 2 and
self.get_outer_ndim(dot_input.owner.inputs[0], args) \ self.get_outer_ndim(dot_input.owner.inputs[0], args) == 3 and
== 3 and self.get_outer_ndim(dot_input.owner.inputs[1], args) == 3):
self.get_outer_ndim(dot_input.owner.inputs[1], args) \
== 3):
# The optimization can be be applied in this case. # The optimization can be be applied in this case.
...@@ -829,36 +826,32 @@ class PushOutScanOutput(gof.Optimizer): ...@@ -829,36 +826,32 @@ class PushOutScanOutput(gof.Optimizer):
(outer_dot_inputs, (outer_dot_inputs,
new_scan_node, new_scan_node,
new_scan_args) = \ new_scan_args) = \
self.push_out_inner_vars(fgraph, self.push_out_inner_vars(fgraph, inner_dot_inputs,
inner_dot_inputs, node, args)
node, args)
# Collapse some of the dimensions of the tensors # Collapse some of the dimensions of the tensors
# so that they become matrices. This is because a # so that they become matrices. This is because a
# dot is usually faster on two large matrices than # dot is usually faster on two large matrices than
# a bunch of small ones # a bunch of small ones
outer_dot_inputs[0] = theano.tensor.flatten( outer_dot_inputs[0] = theano.tensor.flatten(
outer_dot_inputs[0].dimshuffle(1, 0, 2), outer_dot_inputs[0].dimshuffle(1, 0, 2), outdim=2)
outdim=2)
shape_input1 = theano.tensor.shape(outer_dot_inputs[1]) shape_input1 = theano.tensor.shape(outer_dot_inputs[1])
outer_dot_inputs[1] =\ outer_dot_inputs[1] =\
outer_dot_inputs[1].reshape((shape_input1[0] * outer_dot_inputs[1].reshape((shape_input1[0] *
shape_input1[1], shape_input1[1],
shape_input1[2])) shape_input1[2]))
# Perform the dot on the newly obtained matrices and # Perform the dot on the newly obtained matrices and
# add the initial value # add the initial value
outer_dot_output = theano.tensor.dot(*outer_dot_inputs) outer_dot_output = theano.tensor.dot(*outer_dot_inputs)
init_value = \ init_value = new_scan_args.outer_in_sit_sot[sitsot_idx][0]
new_scan_args.outer_in_sit_sot[sitsot_idx][0]
replacement = outer_dot_output + init_value replacement = outer_dot_output + init_value
# Alter the outer graph to use the output of the # Alter the outer graph to use the output of the
# external Dot instead of the output of scan # external Dot instead of the output of scan
# Modify the outer graph to add the outer Dot # Modify the outer graph to add the outer Dot
outer_sitsot = \ outer_sitsot = new_scan_args.outer_out_sit_sot[sitsot_idx]
new_scan_args.outer_out_sit_sot[sitsot_idx]
subtensor_node = outer_sitsot.clients[0][0] subtensor_node = outer_sitsot.clients[0][0]
outer_sitsot_last_step = subtensor_node.outputs[0] outer_sitsot_last_step = subtensor_node.outputs[0]
...@@ -883,12 +876,12 @@ class PushOutScanOutput(gof.Optimizer): ...@@ -883,12 +876,12 @@ class PushOutScanOutput(gof.Optimizer):
if len(outer_var.clients) == 1: if len(outer_var.clients) == 1:
client = outer_var.clients[0][0] client = outer_var.clients[0][0]
if (client != 'output' and if (client != 'output' and isinstance(client.op,
isinstance(client.op, theano.tensor.Subtensor)): theano.tensor.Subtensor)):
lst = theano.tensor.subtensor.get_idx_list( lst = theano.tensor.subtensor.get_idx_list(
client.inputs, client.op.idx_list) client.inputs, client.op.idx_list)
if (len(lst) == 1 and if (len(lst) == 1 and
theano.tensor.extract_constant(lst[0]) == -1): theano.tensor.extract_constant(lst[0]) == -1):
return True return True
return False return False
...@@ -898,7 +891,7 @@ class PushOutScanOutput(gof.Optimizer): ...@@ -898,7 +891,7 @@ class PushOutScanOutput(gof.Optimizer):
# Given a variable, determine the number of dimension it would have if # Given a variable, determine the number of dimension it would have if
# it was pushed out of scan # it was pushed out of scan
if (var in scan_args.inner_in_non_seqs or if (var in scan_args.inner_in_non_seqs or
isinstance(var, theano.Constant)): isinstance(var, theano.Constant)):
outer_ndim = var.ndim outer_ndim = var.ndim
else: else:
...@@ -990,8 +983,8 @@ class PushOutScanOutput(gof.Optimizer): ...@@ -990,8 +983,8 @@ class PushOutScanOutput(gof.Optimizer):
len(old_scan_args.outer_out_shared)) len(old_scan_args.outer_out_shared))
new_node_old_outputs = ( new_node_old_outputs = (
new_scan_node.outputs[:new_node_new_outputs_idx] + new_scan_node.outputs[:new_node_new_outputs_idx] +
new_scan_node.outputs[new_node_new_outputs_idx+nb_new_outs:]) new_scan_node.outputs[new_node_new_outputs_idx + nb_new_outs:])
fgraph.replace_all_validate_remove( fgraph.replace_all_validate_remove(
list(zip(old_scan_node.outputs, new_node_old_outputs)), list(zip(old_scan_node.outputs, new_node_old_outputs)),
...@@ -1017,7 +1010,7 @@ class ScanInplaceOptimizer(Optimizer): ...@@ -1017,7 +1010,7 @@ class ScanInplaceOptimizer(Optimizer):
fgraph.attach_feature(toolbox.ReplaceValidate()) fgraph.attach_feature(toolbox.ReplaceValidate())
fgraph.attach_feature(DestroyHandler()) fgraph.attach_feature(DestroyHandler())
def attempt_scan_inplace(self, fgraph, node, output_indices): def attempt_scan_inplace(self, fgraph, node, output_indices, alloc_ops):
"""Attempts to replace a Scan node by one which computes the specified """Attempts to replace a Scan node by one which computes the specified
outputs inplace. outputs inplace.
...@@ -1029,6 +1022,10 @@ class ScanInplaceOptimizer(Optimizer): ...@@ -1029,6 +1022,10 @@ class ScanInplaceOptimizer(Optimizer):
Scan node to replace by an inplace version Scan node to replace by an inplace version
output_indices : list of integers output_indices : list of integers
Indices of the outputs to attempt to compute inplace Indices of the outputs to attempt to compute inplace
alloc_ops : list of Op classes
Classes that represent operation that allocate new memory and
that the optimization should duplicate so it can operate inplace
on them.
""" """
op = node.op op = node.op
...@@ -1049,6 +1046,14 @@ class ScanInplaceOptimizer(Optimizer): ...@@ -1049,6 +1046,14 @@ class ScanInplaceOptimizer(Optimizer):
ls_end += op.outer_nitsot(node.inputs) ls_end += op.outer_nitsot(node.inputs)
ls_end += op.outer_non_seqs(node.inputs) ls_end += op.outer_non_seqs(node.inputs)
# In `ls`, duplicate any input which has more then one client and is
# the output of an eligible allocation op
for i in range(len(ls)):
inp = ls[i]
if (len(inp.clients) > 1 and inp.owner and
isinstance(inp.owner.op, alloc_ops)):
ls[i] = inp.owner.op(*inp.owner.inputs)
n_outs = len(ls) n_outs = len(ls)
for idx in xrange(n_outs): for idx in xrange(n_outs):
if ls[idx] in ls[:idx]: if ls[idx] in ls[:idx]:
...@@ -1079,6 +1084,21 @@ class ScanInplaceOptimizer(Optimizer): ...@@ -1079,6 +1084,21 @@ class ScanInplaceOptimizer(Optimizer):
def apply(self, fgraph): def apply(self, fgraph):
# Depending on the values of gpu_flag and gpua_flag, get the list of
# memory allocation ops that the optimization should be able to handle
alloc_ops = (Alloc, AllocEmpty)
if self.gpu_flag:
alloc_ops += (theano.sandbox.cuda.GpuAlloc,
theano.sandbox.cuda.GpuAllocEmpty)
if self.gpua_flag:
# gpuarray might be imported but not its GpuAlloc and
# GpuAllopEmpty ops.
try:
alloc_ops += (theano.sandbox.gpuarray.GpuAlloc,
theano.sandbox.gpuarray.GpuAllocEmpty)
except:
pass
nodes = fgraph.toposort()[::-1] nodes = fgraph.toposort()[::-1]
scan_nodes = [x for x in nodes scan_nodes = [x for x in nodes
if (isinstance(x.op, scan_op.Scan) and if (isinstance(x.op, scan_op.Scan) and
...@@ -1101,7 +1121,18 @@ class ScanInplaceOptimizer(Optimizer): ...@@ -1101,7 +1121,18 @@ class ScanInplaceOptimizer(Optimizer):
out_indices = [] out_indices = []
for out_idx in range(n_outs): for out_idx in range(n_outs):
inp_idx = 1 + op.n_seqs + out_idx inp_idx = 1 + op.n_seqs + out_idx
inp = original_node.inputs[inp_idx]
# If the input is from an eligible allocation node, attempt to
# be inplace on it, even if other nodes are modifying it
# inplace.
if inp.owner and isinstance(inp.owner.op, alloc_ops):
out_indices.append(out_idx)
continue
# If the input is not from an eligible allocation node, only
# attempt to be inplace on it if nothing else is currently
# inplace on it.
input_used_inplace = False input_used_inplace = False
for c in original_node.inputs[inp_idx].clients: for c in original_node.inputs[inp_idx].clients:
client = c[0] client = c[0]
...@@ -1122,14 +1153,15 @@ class ScanInplaceOptimizer(Optimizer): ...@@ -1122,14 +1153,15 @@ class ScanInplaceOptimizer(Optimizer):
out_indices.append(out_idx) out_indices.append(out_idx)
node = self.attempt_scan_inplace(fgraph, scan_nodes[scan_idx], node = self.attempt_scan_inplace(fgraph, scan_nodes[scan_idx],
out_indices) out_indices, alloc_ops)
if node is original_node: if node is original_node:
# Making the scan compute all plausible recurrent outputs # Making the scan compute all plausible recurrent outputs
# inplace has failed. Attempt all plausible recurrent output # inplace has failed. Attempt all plausible recurrent output
# individually. # individually.
for pos in out_indices: for pos in out_indices:
node = self.attempt_scan_inplace(fgraph, node, [pos]) node = self.attempt_scan_inplace(fgraph, node, [pos],
alloc_ops)
class ScanSaveMem(gof.Optimizer): class ScanSaveMem(gof.Optimizer):
...@@ -1242,7 +1274,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -1242,7 +1274,7 @@ class ScanSaveMem(gof.Optimizer):
for cl, _ in out.clients: for cl, _ in out.clients:
# 2.1 outputs of the function # 2.1 outputs of the function
#=> output needs all its intermediate values # => output needs all its intermediate values
if type(cl) == str: if type(cl) == str:
# if the node is actually an output, then # if the node is actually an output, then
# we need to store the entire thing # we need to store the entire thing
...@@ -1250,20 +1282,20 @@ class ScanSaveMem(gof.Optimizer): ...@@ -1250,20 +1282,20 @@ class ScanSaveMem(gof.Optimizer):
slices[i] = None slices[i] = None
break break
# 2.2 non-subtensor nodes # 2.2 non-subtensor nodes
#=> output needs all its intermediate values # => output needs all its intermediate values
elif not isinstance(cl.op, tensor.Subtensor): elif not isinstance(cl.op, tensor.Subtensor):
global_nsteps = None global_nsteps = None
slices[i] = None slices[i] = None
break break
# 2.3 subtensor nodes # 2.3 subtensor nodes
#=> output might need to store just a subset of its values # => output might need to store just a subset of its values
else: else:
# 2.3.1 extract idx list of subtensor # 2.3.1 extract idx list of subtensor
this_slice = tensor.get_idx_list(cl.inputs, this_slice = tensor.get_idx_list(cl.inputs,
cl.op.idx_list) cl.op.idx_list)
if this_slice is None: if this_slice is None:
# if unable to extract idx_list # if unable to extract idx_list
#=> outputs needs all its intermediate values # => outputs needs all its intermediate values
global_nsteps = None global_nsteps = None
slices[i] = None slices[i] = None
break break
...@@ -1284,7 +1316,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -1284,7 +1316,7 @@ class ScanSaveMem(gof.Optimizer):
slices[i] += [(cf_slice, this_slice)] slices[i] += [(cf_slice, this_slice)]
if (isinstance(this_slice[0], slice) and if (isinstance(this_slice[0], slice) and
this_slice[0].stop is None): this_slice[0].stop is None):
global_nsteps = None global_nsteps = None
if isinstance(cf_slice[0], slice): if isinstance(cf_slice[0], slice):
stop = tensor.basic.extract_constant(cf_slice[0].stop) stop = tensor.basic.extract_constant(cf_slice[0].stop)
...@@ -1374,7 +1406,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -1374,7 +1406,7 @@ class ScanSaveMem(gof.Optimizer):
break break
if (isinstance(this_slice[0], slice) and if (isinstance(this_slice[0], slice) and
this_slice[0].start is None): this_slice[0].start is None):
store_steps[i] = 0 store_steps[i] = 0
break break
...@@ -1406,8 +1438,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -1406,8 +1438,7 @@ class ScanSaveMem(gof.Optimizer):
# for mitsots and sitsots (because mitmots are not # for mitsots and sitsots (because mitmots are not
# currently supported by the mechanism) and only if # currently supported by the mechanism) and only if
# the pre-allocation mechanism is activated. # the pre-allocation mechanism is activated.
prealloc_outs = \ prealloc_outs = theano.config.scan.allow_output_prealloc
theano.config.scan.allow_output_prealloc
first_mitsot_idx = node.op.n_mit_mot first_mitsot_idx = node.op.n_mit_mot
last_sitsot_idx = (node.op.n_mit_mot + last_sitsot_idx = (node.op.n_mit_mot +
...@@ -1433,7 +1464,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -1433,7 +1464,7 @@ class ScanSaveMem(gof.Optimizer):
# currently. # currently.
# pval = pre_greedy_local_optimizer(list_opt_slice, # pval = pre_greedy_local_optimizer(list_opt_slice,
# pval) # pval)
#pval = pre_constant_merge([pval])[0] # pval = pre_constant_merge([pval])[0]
# if (isinstance(pval, theano.tensor.TensorConstant) # if (isinstance(pval, theano.tensor.TensorConstant)
# and # and
# pval.dtype.startswith('int')): # pval.dtype.startswith('int')):
...@@ -1554,7 +1585,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -1554,7 +1585,7 @@ class ScanSaveMem(gof.Optimizer):
nw_steps) nw_steps)
nw_inputs[in_idx] = nw_input nw_inputs[in_idx] = nw_input
else: else:
nw_input = nw_inputs[in_idx][:(initl+nw_steps)] nw_input = nw_inputs[in_idx][:(initl + nw_steps)]
elif idx < op.n_mit_sot + op.n_sit_sot + op.n_nit_sot: elif idx < op.n_mit_sot + op.n_sit_sot + op.n_nit_sot:
in_idx = offset + idx + op.n_shared_outs in_idx = offset + idx + op.n_shared_outs
...@@ -1563,7 +1594,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -1563,7 +1594,7 @@ class ScanSaveMem(gof.Optimizer):
# 3.5 Remove unwanted orphane outputs # 3.5 Remove unwanted orphane outputs
(inps, outs, info, node_ins, compress_map) = \ (inps, outs, info, node_ins, compress_map) = \
scan_utils.compress_outs(op, not_required, nw_inputs) scan_utils.compress_outs(op, not_required, nw_inputs)
inv_compress_map = OrderedDict() inv_compress_map = OrderedDict()
for k, v in iteritems(compress_map): for k, v in iteritems(compress_map):
inv_compress_map[v] = k inv_compress_map[v] = k
...@@ -1633,15 +1664,15 @@ class ScanSaveMem(gof.Optimizer): ...@@ -1633,15 +1664,15 @@ class ScanSaveMem(gof.Optimizer):
start = (cnf_slice[0].start - nw_steps - start = (cnf_slice[0].start - nw_steps -
init_l[pos] + store_steps[pos]) init_l[pos] + store_steps[pos])
if (cnf_slice[0].stop is not None and if (cnf_slice[0].stop is not None and
cnf_slice[0].stop != maxsize): cnf_slice[0].stop != maxsize):
stop = (cnf_slice[0].stop - nw_steps - stop = (cnf_slice[0].stop - nw_steps -
init_l[pos] + store_steps[pos]) init_l[pos] + store_steps[pos])
else: else:
stop = None stop = None
nw_slice = ((slice(sanitize(start), nw_slice = ((slice(sanitize(start),
sanitize(stop), sanitize(stop),
sanitize(cnf_slice[0].step)),) sanitize(cnf_slice[0].step)),) +
+ tuple(old_slices[1:])) tuple(old_slices[1:]))
else: else:
position = (cnf_slice[0] - nw_steps - position = (cnf_slice[0] - nw_steps -
...@@ -1662,8 +1693,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -1662,8 +1693,7 @@ class ScanSaveMem(gof.Optimizer):
# 3.9. Get replace pairs for all other nodes # 3.9. Get replace pairs for all other nodes
if flag_store or global_nsteps is not None: if flag_store or global_nsteps is not None:
for idx, o in enumerate(node.outputs): for idx, o in enumerate(node.outputs):
if not (idx in replaced_outs) and \ if not (idx in replaced_outs) and idx not in not_required:
not idx in not_required:
nw_pos = compress_map[idx] nw_pos = compress_map[idx]
old_new += [(o, new_outs[nw_pos])] old_new += [(o, new_outs[nw_pos])]
# Check if the new outputs depend on the old scan node # Check if the new outputs depend on the old scan node
...@@ -1808,7 +1838,7 @@ class ScanMerge(gof.Optimizer): ...@@ -1808,7 +1838,7 @@ class ScanMerge(gof.Optimizer):
flat_inner_outs = sum(inner_outs[idx], []) flat_inner_outs = sum(inner_outs[idx], [])
# clone # clone
flat_inner_ins, flat_inner_outs = scan_utils.reconstruct_graph( flat_inner_ins, flat_inner_outs = scan_utils.reconstruct_graph(
flat_inner_ins, flat_inner_outs) flat_inner_ins, flat_inner_outs)
# split the new inner variables again in seq, mitmot, etc. # split the new inner variables again in seq, mitmot, etc.
new_inner_ins = [] new_inner_ins = []
count = 0 count = 0
...@@ -2072,8 +2102,8 @@ def scan_merge_inouts(node): ...@@ -2072,8 +2102,8 @@ def scan_merge_inouts(node):
# because they could have different sizes, and the corresponding # because they could have different sizes, and the corresponding
# outer outputs cannot be merged in that case. # outer outputs cannot be merged in that case.
for s_outer_i, s_inner_o, s_outer_o in seen: for s_outer_i, s_inner_o, s_outer_o in seen:
if (equal_computations([inner_o], [s_inner_o], left, right) if (equal_computations([inner_o], [s_inner_o], left, right) and
and outer_i == s_outer_i): outer_i == s_outer_i):
return s_outer_o return s_outer_o
seen.append((outer_i, inner_o, outer_o)) seen.append((outer_i, inner_o, outer_o))
return outer_o return outer_o
...@@ -2116,9 +2146,10 @@ def scan_merge_inouts(node): ...@@ -2116,9 +2146,10 @@ def scan_merge_inouts(node):
na.outer_out_mit_mot, na.outer_out_mit_mot,
na.mit_mot_out_slices): na.mit_mot_out_slices):
for s_outer_imm, s_inner_omm, s_outer_omm, sosl in seen: for s_outer_imm, s_inner_omm, s_outer_omm, sosl in seen:
if (osl == sosl if (osl == sosl and
and equal_computations(inner_omm, s_inner_omm, left, right) equal_computations(inner_omm, s_inner_omm, left, right) and
and outer_imm == s_outer_imm): outer_imm == s_outer_imm):
new_outer_out_mit_mot.append(s_outer_omm) new_outer_out_mit_mot.append(s_outer_omm)
break break
else: else:
...@@ -2168,17 +2199,15 @@ class PushOutDot1(gof.Optimizer): ...@@ -2168,17 +2199,15 @@ class PushOutDot1(gof.Optimizer):
inp in out.owner.inputs and inp in out.owner.inputs and
len(outer_out.clients) == 1 and len(outer_out.clients) == 1 and
not isinstance(outer_out.clients[0][0], str) and not isinstance(outer_out.clients[0][0], str) and
isinstance(outer_out.clients[0][0].op, theano.tensor.Subtensor) isinstance(outer_out.clients[0][0].op, theano.tensor.Subtensor) and
and outer_out.clients[0][0].op.idx_list == (-1,)): outer_out.clients[0][0].op.idx_list == (-1,)):
x = out.owner.inputs[0] x = out.owner.inputs[0]
if x == inp: if x == inp:
x = out.owner.inputs[1] x = out.owner.inputs[1]
# We need to check if x is the result of an outer product # We need to check if x is the result of an outer product
if (x.owner and if (x.owner and isinstance(x.owner.op, theano.tensor.Dot) and
isinstance(x.owner.op, theano.tensor.Dot) and x.owner.inputs[0].ndim == 2 and x.owner.inputs[1].ndim == 2):
x.owner.inputs[0].ndim == 2 and
x.owner.inputs[1].ndim == 2):
# We need to check if any of the inputs are a sequence # We need to check if any of the inputs are a sequence
inp1 = x.owner.inputs[0] inp1 = x.owner.inputs[0]
...@@ -2219,18 +2248,17 @@ class PushOutDot1(gof.Optimizer): ...@@ -2219,18 +2248,17 @@ class PushOutDot1(gof.Optimizer):
new_info = op.info.copy() new_info = op.info.copy()
st = len(op.mitmot_taps()) + len(op.mitsot_taps()) st = len(op.mitmot_taps()) + len(op.mitsot_taps())
new_info['tap_array'] = (\ new_info['tap_array'] = (
new_info['tap_array'][:st + idx] + new_info['tap_array'][:st + idx] +
new_info['tap_array'][st + new_info['tap_array'][st + idx + 1:])
idx + 1:])
new_info['n_sit_sot'] -= 1 new_info['n_sit_sot'] -= 1
new_info['n_nit_sot'] += 1 new_info['n_nit_sot'] += 1
inner_sitsot = inner_sitsot[:idx] + \ inner_sitsot = (inner_sitsot[:idx] +
inner_sitsot[idx + 1:] inner_sitsot[idx + 1:])
outer_sitsot = outer_sitsot[:idx] + \ outer_sitsot = (outer_sitsot[:idx] +
outer_sitsot[idx + 1:] outer_sitsot[idx + 1:])
inner_sitsot_outs = inner_sitsot_outs[:idx] +\ inner_sitsot_outs = (inner_sitsot_outs[:idx] +
inner_sitsot_outs[idx + 1:] inner_sitsot_outs[idx + 1:])
# add n_steps as the length # add n_steps as the length
inner_nitsot_outs.append(new_scan_out) inner_nitsot_outs.append(new_scan_out)
...@@ -2246,8 +2274,8 @@ class PushOutDot1(gof.Optimizer): ...@@ -2246,8 +2274,8 @@ class PushOutDot1(gof.Optimizer):
inner_nitsot_outs + inner_nitsot_outs +
inner_shared_outs) inner_shared_outs)
new_inner_inps, new_inner_outs =\ new_inner_inps, new_inner_outs =\
scan_utils.reconstruct_graph( scan_utils.reconstruct_graph(_new_inner_inps,
_new_inner_inps, _new_inner_outs) _new_inner_outs)
new_op = scan_op.Scan(new_inner_inps, new_inner_outs, new_op = scan_op.Scan(new_inner_inps, new_inner_outs,
new_info) new_info)
_scan_inputs = ([node.inputs[0]] + _scan_inputs = ([node.inputs[0]] +
...@@ -2267,11 +2295,7 @@ class PushOutDot1(gof.Optimizer): ...@@ -2267,11 +2295,7 @@ class PushOutDot1(gof.Optimizer):
# We need now to pair correctly the new outputs # We need now to pair correctly the new outputs
# with the old ones # with the old ones
outer_mitmot_outs = new_op.outer_mitmot_outs(new_outs)
outer_mitsot_outs = new_op.outer_mitsot_outs(new_outs)
outer_sitsot_outs = new_op.outer_sitsot_outs(new_outs)
outer_nitsot_outs = new_op.outer_nitsot_outs(new_outs) outer_nitsot_outs = new_op.outer_nitsot_outs(new_outs)
outer_shared_outs = new_op.outer_shared_outs(new_outs)
_val = outer_nitsot_outs[-1] _val = outer_nitsot_outs[-1]
outer_nitsot_outs = outer_nitsot_outs[:-1] outer_nitsot_outs = outer_nitsot_outs[:-1]
...@@ -2305,7 +2329,7 @@ class PushOutDot1(gof.Optimizer): ...@@ -2305,7 +2329,7 @@ class PushOutDot1(gof.Optimizer):
old_new = list(zip(node.outputs[:pos], new_outs[:pos])) old_new = list(zip(node.outputs[:pos], new_outs[:pos]))
old = node.outputs[pos].clients[0][0].outputs[0] old = node.outputs[pos].clients[0][0].outputs[0]
old_new.append((old, new_out)) old_new.append((old, new_out))
old_new += list(zip(node.outputs[pos+1:], old_new += list(zip(node.outputs[pos + 1:],
new_outs[pos:])) new_outs[pos:]))
fgraph.replace_all_validate_remove( fgraph.replace_all_validate_remove(
old_new, remove=[node], reason='scan_pushout_dot1') old_new, remove=[node], reason='scan_pushout_dot1')
...@@ -2374,61 +2398,61 @@ scan_seqopt1.register('scanOp_pushout_output', ...@@ -2374,61 +2398,61 @@ scan_seqopt1.register('scanOp_pushout_output',
scan_eqopt2.register('constant_folding_for_scan2', scan_eqopt2.register('constant_folding_for_scan2',
opt.in2out(tensor.opt.constant_folding, opt.in2out(tensor.opt.constant_folding,
ignore_newtrees=True), ignore_newtrees=True),
1, 1,
'fast_run', 'fast_run',
'scan') 'scan')
scan_eqopt2.register('scanOp_remove_constants_and_unused_inputs1', scan_eqopt2.register('scanOp_remove_constants_and_unused_inputs1',
opt.in2out(remove_constants_and_unused_inputs_scan, opt.in2out(remove_constants_and_unused_inputs_scan,
ignore_newtrees=True), ignore_newtrees=True),
2, 2,
'remove_constants_and_unused_inputs_scan', 'remove_constants_and_unused_inputs_scan',
'fast_run', 'fast_run',
'scan') 'scan')
# after const merge but before stabilize so that we can have identity # after const merge but before stabilize so that we can have identity
# for equivalent nodes but we still have the chance to hoist stuff out # for equivalent nodes but we still have the chance to hoist stuff out
# of the scan later. # of the scan later.
scan_eqopt2.register('scanOp_merge', scan_eqopt2.register('scanOp_merge',
ScanMerge(), ScanMerge(),
4, 4,
'fast_run', 'fast_run',
'scan') 'scan')
# After Merge optimization # After Merge optimization
scan_eqopt2.register('scanop_remove_constants_and_unused_inputs2', scan_eqopt2.register('scanop_remove_constants_and_unused_inputs2',
opt.in2out(remove_constants_and_unused_inputs_scan, opt.in2out(remove_constants_and_unused_inputs_scan,
ignore_newtrees=True), ignore_newtrees=True),
5, 5,
'remove_constants_and_unused_inputs_scan', 'remove_constants_and_unused_inputs_scan',
'fast_run', 'fast_run',
'scan') 'scan')
scan_eqopt2.register('scanOp_merge_inouts', scan_eqopt2.register('scanOp_merge_inouts',
opt.in2out(scan_merge_inouts, ignore_newtrees=True), opt.in2out(scan_merge_inouts, ignore_newtrees=True),
6, 6,
'scan_merge_inouts', 'scan_merge_inouts',
'fast_run', 'fast_run',
'scan') 'scan')
# Just before specialize to have the other optimization # Just before specialize to have the other optimization
# like constant folding being applied # like constant folding being applied
# This don't introduce inplace. # This don't introduce inplace.
scan_eqopt2.register('scanOp_save_mem', scan_eqopt2.register('scanOp_save_mem',
ScanSaveMem(), ScanSaveMem(),
7, 7,
'fast_run', 'fast_run',
'scan') 'scan')
# After everything else # After everything else
scan_eqopt2.register('scanOp_remove_constants_and_unused_inputs3', scan_eqopt2.register('scanOp_remove_constants_and_unused_inputs3',
opt.in2out(remove_constants_and_unused_inputs_scan, opt.in2out(remove_constants_and_unused_inputs_scan,
ignore_newtrees=True), ignore_newtrees=True),
8, 8,
'remove_constants_and_unused_inputs_scan', 'remove_constants_and_unused_inputs_scan',
'fast_run', 'fast_run',
'scan') 'scan')
...@@ -164,7 +164,6 @@ whitelist_flake8 = [ ...@@ -164,7 +164,6 @@ whitelist_flake8 = [
"scan_module/scan_op.py", "scan_module/scan_op.py",
"scan_module/scan_perform_ext.py", "scan_module/scan_perform_ext.py",
"scan_module/__init__.py", "scan_module/__init__.py",
"scan_module/scan_opt.py",
"scan_module/tests/test_scan.py", "scan_module/tests/test_scan.py",
"scan_module/tests/test_scan_opt.py", "scan_module/tests/test_scan_opt.py",
"misc/tests/test_may_share_memory.py", "misc/tests/test_may_share_memory.py",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论