提交 ed415454 authored 作者: Razvan Pascanu's avatar Razvan Pascanu 提交者: Razvan Pascanu

More sane inplace optimization

The optimization before was all or nothing, which is a bad compromise for scan. This new optimization tries to work in place on each outputs, and keeps only those for which it can. This way you do not get your entire op to be non-inplace because of a single output that can not be computed inplace.
上级 f57b7b77
import logging import logging
_logger = logging.getLogger('theano.sandbox.cuda.opt') _logger = logging.getLogger('theano.sandbox.cuda.opt')
import copy
import sys import sys
import warnings import warnings
import numpy import numpy
import theano import theano
from theano.scan_module import scan_utils, scan_op, scan_opt
from theano import scalar as scal from theano import scalar as scal
from theano import tensor, compile, gof from theano import tensor, compile, gof
from theano.compile import optdb from theano.compile import optdb
from theano.gof import (local_optimizer, EquilibriumDB, SequenceDB, ProxyDB, from theano.gof import (local_optimizer, EquilibriumDB, SequenceDB, ProxyDB,
Optimizer, toolbox, DestroyHandler, Optimizer, toolbox, DestroyHandler,
EquilibriumOptimizer) InconsistencyError, EquilibriumOptimizer)
from theano.gof.python25 import all, any from theano.gof.python25 import all, any
from theano.sandbox.cuda.basic_ops import * from theano.sandbox.cuda.basic_ops import *
from theano.sandbox.cuda.type import CudaNdarrayType from theano.sandbox.cuda.type import CudaNdarrayType
...@@ -1431,7 +1433,7 @@ def gpuScanOptimization(node): ...@@ -1431,7 +1433,7 @@ def gpuScanOptimization(node):
# merged or implement this optimization as a global # merged or implement this optimization as a global
# optimization # optimization
thescan = host_input.owner.op thescan = host_input.owner.op
info = thescan.info.copy() info = copy.deepcopy(thescan.info)
info['gpu'] = True info['gpu'] = True
inputs = host_input.owner.inputs inputs = host_input.owner.inputs
nw_ins = [inputs[0]] nw_ins = [inputs[0]]
...@@ -1478,7 +1480,7 @@ def gpuScanOptimization(node): ...@@ -1478,7 +1480,7 @@ def gpuScanOptimization(node):
for i in node.inputs]): for i in node.inputs]):
thescan = node.op thescan = node.op
info = thescan.info.copy() info = copy.deepcopy(thescan.info)
info['gpu'] = True info['gpu'] = True
inputs = node.inputs inputs = node.inputs
nw_ins = [inputs[0]] nw_ins = [inputs[0]]
...@@ -1527,41 +1529,9 @@ def gpuScanOptimization(node): ...@@ -1527,41 +1529,9 @@ def gpuScanOptimization(node):
return False return False
@gof.local_optimizer([None])
def gpu_scan_make_inplace(node):
op = node.op
if (isinstance(op, scan_op.Scan) and
(not op.info['inplace']) and
(op.info['gpu'])):
info = op.info.copy()
info['inplace'] = True
# inputs corresponding to sequences and n_steps
ls_begin = node.inputs[:1 + op.n_seqs]
ls = op.outer_mitmot(node)
ls += op.outer_mitsot(node)
ls += op.outer_sitsot(node)
ls_end = op.outer_shared(node)
ls_end += op.outer_nitsot(node)
ls_end += op.outer_non_seqs(node)
n_outs = len(ls)
for idx in xrange(n_outs):
if ls[idx] in ls[:idx]:
ls[idx] = compile.function_module.deep_copy_op(ls[idx])
inputs = ls_begin + ls + ls_end
typeConstructor = lambda broadcastable, dtype: CudaNdarrayType(
broadcastable=broadcastable)
new_op = scan_op.Scan(op.inputs,
op.outputs,
info,
typeConstructor=typeConstructor)
return new_op.make_node(*inputs).outputs
return False
optdb.register('gpu_scanOp_make_inplace', optdb.register('gpu_scanOp_make_inplace',
theano.tensor.opt.in2out( scan_opt.ScanInplaceOptimizer(typeConstructor=CudaNdarrayType,
gpu_scan_make_inplace, ignore_newtrees=True), gpu_flag=True),
75, 75,
'gpu', 'gpu',
'fast_run', 'fast_run',
......
"""
This module provides a different interface for the Scan Op.
This is a sligthly more advanced interface that helps avoiding certain
issues that scan can cause.
"""
__docformat__ = 'restructedtext en'
__authors__ = "Razvan Pascanu "
__copyright__ = "(c) 2010, Universite de Montreal"
__contact__ = "Razvan Pascanu <r.pascanu@gmail>"
import itertools
import logging
import numpy
from theano.compile import SharedVariable, function
from theano import compile
from theano import gof
from theano.tensor import opt
from theano import tensor
from theano import config
from theano.updates import Updates
from theano.scan_module import scan_op
from theano.scan_module import scan_utils
from theano.scan_module.scan_utils import safe_new, traverse
# Logging function for sending warning or info
_logger = logging.getLogger('theano.scan_module.scan')
def scan(fn,
sequences=None,
states=None,
params=None,
n_steps=None,
truncate_gradient=-1,
mode=None,
name=None,
profile=False):
"""
WRITE ME !
"""
# General observation : this code is executed only once, at creation
# of the computational graph, so we don't yet need to be smart about
# anything (to speed things up)
##
### Step 1. Wrap all inputs in dictionaries and add default values
##
# check if inputs are just single variables instead of lists
def wrap_into_list(x):
'''
Wrap the input into a list if it is not already a list
'''
if x is None:
return []
elif not isinstance(x, (list, tuple)):
return [x]
else:
return list(x)
seqs = wrap_into_list(sequences)
outs_info = wrap_into_list(states)
# Make sure we get rid of numpy arrays or ints or anything like that
# passed as inputs to scan
non_seqs = []
for elem in wrap_into_list(params):
if not isinstance(elem, gof.Variable):
non_seqs.append(tensor.as_tensor_variable(elem))
else:
non_seqs.append(elem)
# If we provided a known number of steps ( before compilation)
# and if that number is 1 or -1, then we can skip the Scan Op,
# and just apply the inner function once
# To do that we check here to see the nature of n_steps
n_fixed_steps = None
if isinstance(n_steps, (float, int)):
n_fixed_steps = int(n_steps)
else:
try:
n_fixed_steps = opt.get_constant_value(n_steps)
except (TypeError, AttributeError):
n_fixed_steps = None
# Check n_steps is an int
if (hasattr(n_steps, 'dtype') and
str(n_steps.dtype)[:3] not in ('uin', 'int')):
raise ValueError(' n_steps must be an int. dtype provided '
'is %s' % n_steps.dtype)
# compute number of sequences and number of outputs
n_seqs = len(seqs)
n_outs = len(outs_info)
return_steps = {}
# wrap sequences in a dictionary if they are not already dictionaries
for i in xrange(n_seqs):
if not isinstance(seqs[i], dict):
seqs[i] = dict(input=seqs[i], taps=[0])
elif seqs[i].get('taps', None):
seqs[i]['taps'] = wrap_into_list(seqs[i]['taps'])
elif seqs[i].get('taps', True) is None:
# seqs dictionary does not have the ``taps`` key
seqs[i]['taps'] = [0]
# wrap outputs info in a dictionary if they are not already in one
for i in xrange(n_outs):
if outs_info[i] is not None:
if isinstance(outs_info[i], dict):
# DEPRECATED :
if outs_info[i].get('return_steps', None):
raise ValueError(
"Using `return_steps` has been deprecated. "
"Simply select the entries you need using a "
"subtensor. Scan will optimize memory "
"consumption, so do not worry about that.")
# END
if not isinstance(outs_info[i], dict):
# by default any output has a tap value of -1
outs_info[i] = dict(initial=outs_info[i], taps=[-1])
elif (not outs_info[i].get('initial', None) and
outs_info[i].get('taps', None)):
# ^ no initial state but taps provided
raise ValueError(('If you are using slices of an output '
'you need to provide a initial state '
'for it'), outs_info[i])
elif (outs_info[i].get('initial', None) and
not outs_info[i].get('taps', None)):
# ^ initial state but taps not provided
if 'taps' in outs_info[i]:
# ^ explicitly provided a None for taps
_logger.warning('Output %s ( index %d) has a initial '
'state but taps is explicitly set to None ',
getattr(outs_info[i]['initial'], 'name', 'None'),
i)
outs_info[i]['taps'] = [-1]
else:
# if a None is provided as the output info we replace it
# with an empty dict() to simplify handling
outs_info[i] = dict()
##
### Step 2. Generate inputs and outputs of the inner functions
### for compiling a dummy function (Iteration #1)
##
# create theano inputs for the recursive function
# note : this is a first batch of possible inputs that will
# be compiled in a dummy function; we used this dummy
# function to detect shared variables and their updates
# and to construct a new and complete list of inputs and
# outputs
n_seqs = 0
scan_seqs = [] # Variables passed as inputs to the scan op
inner_seqs = [] # Variables passed as inputs to the inner function
inner_slices = [] # Actual slices if scan is removed from the picture
# go through sequences picking up time slices as needed
for i, seq in enumerate(seqs):
actual_slice = seq['input'][0]
_seq_val = tensor.as_tensor_variable(seq['input'])
_seq_val_slice = _seq_val[0]
nw_slice = _seq_val_slice.type()
if seq['input'].name:
nw_slice.name=seq['input'].name + '[t]'
scan_seqs.append(_seq_val)
inner_seqs.append(nw_slice)
inner_slices.append(actual_slice)
n_seqs += 1
actual_n_steps = tensor.as_tensor(n_steps)
# Add names -- it helps a lot when debugging
for (nw_seq, seq) in zip(scan_seqs, seqs):
if getattr(seq['input'], 'name', None) is not None:
nw_seq.name = seq['input'].name + '[%d:]' % k
# Conventions :
# mit_mot = multiple input taps, multiple output taps ( only provided
# by the gradient function )
# mit_sot = multiple input taps, single output tap (t + 0)
# sit_sot = single input tap, single output tap (t + 0)
# nit_sot = no input tap, single output tap (t + 0)
# MIT_MOT -- not provided by the user only by the grad function
n_mit_mot = 0
n_mit_mot_outs = 0
mit_mot_scan_inputs = []
mit_mot_inner_inputs = []
mit_mot_inner_outputs = []
mit_mot_out_slices = []
mit_mot_rightOrder = []
# SIT_SOT -- provided by the user
n_mit_sot = 0
mit_sot_scan_inputs = []
mit_sot_inner_inputs = []
mit_sot_inner_slices = []
mit_sot_inner_outputs = []
mit_sot_return_steps = {}
mit_sot_tap_array = []
mit_sot_rightOrder = []
n_sit_sot = 0
sit_sot_scan_inputs = []
sit_sot_inner_inputs = []
sit_sot_inner_slices = []
sit_sot_inner_outputs = []
sit_sot_return_steps = {}
sit_sot_rightOrder = []
nit_sot_steps = []
# go through outputs picking up time slices as needed
for i, init_out in enumerate(outs_info):
# Note that our convention dictates that if an output uses
# just the previous time step, as a initial state we will only
# provide a tensor of the same dimension as one time step; This
# makes code much cleaner for those who do not use taps. Otherwise
# they would always had to shape_padleft the initial state ..
# which is ugly
if init_out.get('taps', None) == [-1]:
actual_arg = init_out['initial']
arg = safe_new(init_out['initial'][0])
if isinstance(arg, tensor.Constant):
# safe new returns a clone of the constants, but that is not
# what we need for initial states
arg = arg.type()
# Try to transfer test_value to the new variable
if config.compute_test_value != 'off':
try:
arg.tag.test_value = gof.Op._get_test_value(actual_arg)
except AttributeError, e:
if config.compute_test_value != 'ignore':
# No need to print a warning or raise an error now,
# it will be done when fn will be called.
_logger.info(('Cannot compute test value for the '
'inner function of scan, input value missing %s'),
e)
if getattr(init_out['initial'], 'name', None) is not None:
arg.name = init_out['initial'].name + '[t-1]'
# We need now to allocate space for storing the output and copy
# the initial state over. We do this using the expand function
# defined in scan utils
sit_sot_scan_inputs.append(actual_arg)
sit_sot_inner_slices.append(actual_arg[0])
if i in return_steps:
sit_sot_return_steps[n_sit_sot] = return_steps[i]
sit_sot_inner_inputs.append(arg)
sit_sot_rightOrder.append(i)
n_sit_sot += 1
elif init_out.get('taps', None):
if numpy.any(numpy.array(init_out.get('taps', [])) > 0):
# Make sure we do not have requests for future values of a
# sequence we can not provide such values
raise ValueError('Can not use future taps of outputs',
init_out)
# go through the taps
mintap = abs(numpy.min(init_out['taps']))
mit_sot_tap_array.append(init_out['taps'])
idx_offset = abs(numpy.min(init_out['taps']))
# Sequence
mit_sot_scan_inputs.append(init_out['initial'])
if i in return_steps:
mit_sot_return_steps[n_mit_sot] = return_steps[i]
mit_sot_rightOrder.append(i)
n_mit_sot += 1
for k in init_out['taps']:
# create a new slice
actual_nw_slice = init_out['initial'][k + mintap]
_init_out_var = tensor.as_tensor_variable(init_out['initial'])
_init_out_var_slice = _init_out_var[k + mintap]
nw_slice = _init_out_var_slice.type()
# Try to transfer test_value to the new variable
if config.compute_test_value != 'off':
try:
nw_slice.tag.test_value = gof.Op._get_test_value(
_init_out_var_slice)
except AttributeError, e:
if config.compute_test_value != 'ignore':
# No need to print a warning or raise an error now,
# it will be done when fn will be called.
_logger.info(('Cannot compute test value for '
'the inner function of scan, input value '
'missing. %s'), e)
# give it a name or debugging and pretty printing
if getattr(init_out['initial'], 'name', None) is not None:
if k > 0:
nw_slice.name = (init_out['initial'].name +
'[t+%d]' % k)
elif k == 0:
nw_slice.name = init_out['initial'].name + '[t]'
else:
nw_slice.name = (init_out['initial'].name +
'[t%d]' % k)
mit_sot_inner_inputs.append(nw_slice)
mit_sot_inner_slices.append(actual_nw_slice)
else:
nit_sot_steps.append(init_out['steps'])
# Re-order args
max_mit_sot = numpy.max([-1] + mit_sot_rightOrder) + 1
max_sit_sot = numpy.max([-1] + sit_sot_rightOrder) + 1
n_elems = numpy.max([max_mit_sot, max_sit_sot])
_ordered_args = [[] for x in xrange(n_elems)]
offset = 0
for idx in xrange(n_mit_sot):
n_inputs = len(mit_sot_tap_array[idx])
if n_fixed_steps in [1, -1]:
_ordered_args[mit_sot_rightOrder[idx]] = \
mit_sot_inner_slices[offset:offset + n_inputs]
else:
_ordered_args[mit_sot_rightOrder[idx]] = \
mit_sot_inner_inputs[offset:offset + n_inputs]
offset += n_inputs
for idx in xrange(n_sit_sot):
if n_fixed_steps in [1, -1]:
_ordered_args[sit_sot_rightOrder[idx]] = \
[sit_sot_inner_slices[idx]]
else:
_ordered_args[sit_sot_rightOrder[idx]] = \
[sit_sot_inner_inputs[idx]]
ordered_args = []
for ls in _ordered_args:
ordered_args += ls
if n_fixed_steps in [1, -1]:
args = (inner_slices +
ordered_args +
non_seqs)
else:
args = (inner_seqs +
ordered_args +
non_seqs)
# add only the non-shared variables and non-constants to the arguments of
# the dummy function [ a function should not get shared variables or
# constants as input ]
dummy_args = [arg for arg in args
if (not isinstance(arg, SharedVariable) and
not isinstance(arg, tensor.Constant))]
# when we apply the lambda expression we get a mixture of update rules
# and outputs that needs to be separated
condition, outputs, updates = scan_utils.get_updates_and_outputs(fn(*args))
if condition is not None:
as_while = True
else:
as_while = False
##
### Step 3. Check if we actually need scan and remove it if we don't
##
if n_fixed_steps in [1, -1]:
# We do not need to use the scan op anymore, so we can just return
# the outputs and updates we have
if condition is not None:
_logger.warning(('When the number of steps is fixed and equal '
'to 1, the provided stopping condition, ',
str(condition), ' is ignored'))
for pos, inner_out in enumerate(outputs):
# we need to see if we need to pad our sequences with an
# unbroadcastable dimension; case example : we return an
# output for which we want all intermediate. If n_steps is 1
# then, if we return the output as given by the innner function
# this will represent only a slice and it will have one
# dimension less.
if (isinstance(inner_out.type, tensor.TensorType) and
return_steps.get(pos, 0) != 1):
outputs[pos] = tensor.unbroadcast(
tensor.shape_padleft(inner_out), 0)
if len(outputs) == 1:
outputs = outputs[0]
return (outputs, updates)
##
### Step 4. Compile the dummy function
##
# We can now compile a dummy function just to see what shared variable
# we have and what are their update rules (note that the user has
# the option not to pass the shared variable to scan, so we need to
# pick them manually and add them to scan)
# make the compilation as fast as possible by not applying any
# optimization or conversion to C [ note this region is not important
# for performance so we can do stuff as unoptimal as we wish ]
# extract still missing inputs (there still might be so) and add them
# as non sequences at the end of our args
fake_nonseqs = [x.type() for x in non_seqs]
fake_outputs = scan_utils.clone(outputs,
replace=dict(zip(non_seqs,
fake_nonseqs)))
all_inputs = itertools.ifilter(
lambda x: (isinstance(x, gof.Variable) and
not isinstance(x, SharedVariable) and
not isinstance(x, gof.Constant)),
gof.graph.inputs(fake_outputs))
extra_inputs = filter(lambda x: x not in args + fake_nonseqs,
all_inputs)
non_seqs += extra_inputs
## Note we do not use all_inputs directly since the order of variables
## in args is quite important
dummy_args += extra_inputs
dummy_outs = outputs
if condition is not None:
dummy_outs.append(condition)
dummy_f = function(dummy_args,
dummy_outs,
updates=updates,
mode=compile.mode.Mode(linker='py',
optimizer=None))
##
### Step 5. Re-arange inputs of scan into a more strict order
##
## Step 5.0 Check the outputs of the dummy function to see if they
## match with user provided data
# if the number of outputs to the function does not match the number of
# assumed outputs until now (provided by the user) there can be
# only one explanation: No information is provided for any of the
# outputs (i.e. we are dealing with a map)
tmp_dummy_f_outs = len(dummy_f.maker.outputs)
if as_while:
tmp_dummy_f_outs -= 1
if not (tmp_dummy_f_outs == n_outs or outs_info == []):
raise ValueError('Please provide None as output_info for '
'any output that does not feed back into '
'scan (i.e. it behaves like a map) ')
if outs_info == []:
n_outs = len(dummy_f.maker.outputs)
if as_while:
n_outs = n_outs - 1
outs_info = [dict() for x in xrange(n_outs)]
## Step 5.1 Outputs with taps different then -1
for i, out in enumerate(outs_info):
if 'taps' in out and out['taps'] != [-1]:
mit_sot_inner_outputs.append(outputs[i])
## Step 5.2 Outputs with tap equal to -1
for i, out in enumerate(outs_info):
if 'taps' in out and out['taps'] == [-1]:
sit_sot_inner_outputs.append(outputs[i])
## Step 5.3 Outputs that correspond to update rules of shared variables
givens = {}
n_shared_outs = 0
shared_scan_inputs = []
shared_inner_inputs = []
shared_inner_outputs = []
for input in dummy_f.maker.expanded_inputs:
if isinstance(input.variable, SharedVariable) and input.update:
new_var = safe_new(input.variable)
if getattr(input.variable, 'name', None) is not None:
new_var.name = input.variable.name + '_copy'
shared_inner_inputs.append(new_var)
shared_scan_inputs.append(input.variable)
shared_inner_outputs.append(input.update)
givens[input.variable] = new_var
n_shared_outs += 1
## Step 5.4 Outputs with no taps used in the input
n_nit_sot = 0
nit_sot_inner_outputs = []
nit_sot_return_steps = {}
nit_sot_rightOrder = []
for i, out in enumerate(outs_info):
if not 'taps' in out:
nit_sot_inner_outputs.append(outputs[i])
if i in return_steps:
nit_sot_return_steps[n_nit_sot] = return_steps[i]
nit_sot_rightOrder.append(i)
n_nit_sot += 1
## Step 5.5 all other arguments including extra inputs
other_scan_args = []
other_inner_args = []
other_scan_args += [arg for arg in non_seqs
if (not isinstance(arg, SharedVariable) and
not isinstance(arg, tensor.Constant))]
## Step 5.6 all shared variables with no update rules
other_inner_args += [safe_new(arg, '_copy') for arg in non_seqs
if (not isinstance(arg, SharedVariable) and
not isinstance(arg, tensor.Constant))]
givens.update(dict(zip(other_scan_args, other_inner_args)))
other_shared_scan_args = [arg.variable for arg
in dummy_f.maker.expanded_inputs
if (isinstance(arg.variable, SharedVariable) and
not arg.update)]
other_shared_inner_args = [safe_new(arg.variable, '_copy') for arg
in dummy_f.maker.expanded_inputs
if (isinstance(arg.variable, SharedVariable) and
not arg.update)]
givens.update(dict(zip(other_shared_scan_args,
other_shared_inner_args)))
##
### Step 6. Re-order the outputs and clone them replacing things
### using the givens
##
inner_inputs = (inner_seqs +
mit_mot_inner_inputs +
mit_sot_inner_inputs +
sit_sot_inner_inputs +
shared_inner_inputs +
other_shared_inner_args +
other_inner_args)
inner_outs = (mit_mot_inner_outputs +
mit_sot_inner_outputs +
sit_sot_inner_outputs +
nit_sot_inner_outputs +
shared_inner_outputs)
if condition is not None:
inner_outs.append(condition)
# Cuda is imported here, instead of being imported on top of the file
# because forces on the user some dependencies that we might do not want
# to. Currently we are working on removing the dependencies on sandbox
# code completeley.
from theano.sandbox import cuda
if cuda.cuda_available:
# very often we end up in this situation when we want to
# replace w with w_copy, where w is CudaNdarray
# and w_copy is TensorType. This is caused because shared
# variables are put on GPU right aways >:| ,
new_givens = {}
for w, w_copy in givens.iteritems():
if (isinstance(w.type, cuda.CudaNdarrayType)
and isinstance(w_copy.type, tensor.TensorType)):
for o in inner_outs:
new_givens = traverse(o, w, w_copy, new_givens)
else:
new_givens[w] = w_copy
else:
new_givens = givens
new_outs = scan_utils.clone(inner_outs, replace=new_givens)
##
### Step 7. Create the Scan Op
##
tap_array = mit_sot_tap_array + [[-1] for x in xrange(n_sit_sot)]
info = {}
info['tap_array'] = tap_array
info['n_seqs'] = n_seqs
info['n_mit_mot'] = n_mit_mot
info['n_mit_mot_outs'] = n_mit_mot_outs
info['mit_mot_out_slices'] = mit_mot_out_slices
info['n_mit_sot'] = n_mit_sot
info['n_sit_sot'] = n_sit_sot
info['n_shared_outs'] = n_shared_outs
info['n_nit_sot'] = n_nit_sot
info['truncate_gradient'] = truncate_gradient
info['name'] = name
info['mode'] = mode
info['inplace'] = False
info['gpu'] = False
info['as_while'] = as_while
info['profile'] = profile
info['_scan_merge_visited'] = True
local_op = scan_op.Scan(inner_inputs, new_outs, info)
##
### Step 8. Compute the outputs using the scan op
##
_scan_inputs = (scan_seqs +
mit_mot_scan_inputs +
mit_sot_scan_inputs +
sit_sot_scan_inputs +
shared_scan_inputs +
nit_sot_steps +
other_shared_scan_args +
other_scan_args)
scan_inputs = []
for arg in [actual_n_steps] + _scan_inputs:
try:
arg = tensor.as_tensor_variable(arg)
except TypeError:
# This happens for Random States for e.g. but it is a good way
# to make sure no input is a cuda ndarrays
pass
scan_inputs += [arg]
scan_outs = local_op(*scan_inputs)
if type(scan_outs) not in (list, tuple):
scan_outs = [scan_outs]
##
### Step 9. Figure out which outs are update rules for shared variables
### and so on ...
##
update_map = Updates()
def remove_dimensions(outs, steps_return, offsets=None):
out_ls = []
for idx, out in enumerate(outs):
if idx in steps_return:
if steps_return[idx] > 1:
out_ls.append(out[-steps_return[idx]:])
else:
out_ls.append(out[-1])
else:
if offsets is None:
out_ls.append(out)
else:
out_ls.append(out[offsets[idx]:])
return out_ls
offset = n_mit_mot
offsets = [abs(numpy.min(x)) for x in mit_sot_tap_array]
mit_sot_outs = scan_outs[offset:offset + n_mit_sot]
# mit_sot_outs = remove_dimensions(
# scan_outs[offset:offset + n_mit_sot],
# mit_sot_return_steps,
# offsets)
offset += n_mit_sot
offsets = [1 for x in xrange(n_sit_sot)]
sit_sot_outs = scan_outs[offset:offset + n_sit_sot]
# remove_dimensions(
# scan_outs[offset:offset + n_sit_sot],
# sit_sot_return_steps,
# offsets)
offset += n_sit_sot
nit_sot_outs = scan_outs[offset:offset + n_nit_sot]
#nit_sot_outs = remove_dimensions(
# scan_outs[offset:offset + n_nit_sot],
# nit_sot_return_steps)
offset += n_nit_sot
for idx, update_rule in enumerate(
scan_outs[offset:offset + n_shared_outs]):
update_map[shared_scan_inputs[idx]] = update_rule
_scan_out_list = (mit_sot_outs +
sit_sot_outs +
nit_sot_outs)
# Step 10. I need to reorder the outputs to be in the order expected by
# the user
rightOrder = (mit_sot_rightOrder +
sit_sot_rightOrder +
nit_sot_rightOrder)
scan_out_list = [None] * len(rightOrder)
for idx, pos in enumerate(rightOrder):
scan_out_list[pos] = _scan_out_list[idx]
if len(scan_out_list) == 1:
scan_out_list = scan_out_list[0]
elif len(scan_out_list) == 0:
scan_out_list = None
return (scan_out_list, update_map)
...@@ -949,7 +949,8 @@ def scan(fn, ...@@ -949,7 +949,8 @@ def scan(fn,
info['truncate_gradient'] = truncate_gradient info['truncate_gradient'] = truncate_gradient
info['name'] = name info['name'] = name
info['mode'] = mode info['mode'] = mode
info['inplace'] = False info['inplace'] = -1
info['destroy_map'] = {}
info['gpu'] = False info['gpu'] = False
info['as_while'] = as_while info['as_while'] = as_while
info['profile'] = profile info['profile'] = profile
......
...@@ -94,12 +94,6 @@ class Scan(PureOp): ...@@ -94,12 +94,6 @@ class Scan(PureOp):
if self.as_while: if self.as_while:
self.output_types = self.output_types[:-1] self.output_types = self.output_types[:-1]
self.destroy_map = {}
if hasattr(self, 'inplace') and self.inplace:
for idx in xrange(self.n_mit_mot + self.n_mit_sot +
self.n_sit_sot):
self.destroy_map[idx] = [idx + 1 + self.n_seqs]
mode_instance = compile.mode.get_mode(self.mode) mode_instance = compile.mode.get_mode(self.mode)
# if the default mode is used, and that mode is ProfileMode # if the default mode is used, and that mode is ProfileMode
...@@ -411,12 +405,22 @@ class Scan(PureOp): ...@@ -411,12 +405,22 @@ class Scan(PureOp):
name = 'do_while' name = 'do_while'
else: else:
name = 'for' name = 'for'
aux_txt = '%s'
if self.inplace: if len(self.destroy_map.keys()) > 0:
aux_txt = '%s{inplace,%s,%s}' % (name, gpu_str, str(self.name)) # Check if all outputs are inplace
if (sorted(self.destroy_map.keys()) == \
sorted(range(self.n_mit_mot +
self.n_mit_sot +
self.n_sit_sot))):
aux_txt += 'all_inplace,%s,%s}'
else: else:
aux_txt = '%s{%s,%s}' % (name, gpu_str, str(self.name)) aux_txt += '{inplace{'
for k in self.destroy_map.keys():
aux_txt += str(k) + ','
aux_txt += '},%s,%s}'
else:
aux_txt +='{%s,%s}'
aux_txt = aux_txt % (name, gpu_str, str(self.name))
return aux_txt return aux_txt
def __hash__(self): def __hash__(self):
......
...@@ -13,6 +13,7 @@ __copyright__ = "(c) 2010, Universite de Montreal" ...@@ -13,6 +13,7 @@ __copyright__ = "(c) 2010, Universite de Montreal"
__contact__ = "Razvan Pascanu <r.pascanu@gmail>" __contact__ = "Razvan Pascanu <r.pascanu@gmail>"
import logging import logging
import copy
import numpy import numpy
import theano import theano
...@@ -20,6 +21,8 @@ from theano import tensor ...@@ -20,6 +21,8 @@ from theano import tensor
from theano.tensor import opt, get_constant_value from theano.tensor import opt, get_constant_value
from theano import gof from theano import gof
from theano.gof.python25 import maxsize from theano.gof.python25 import maxsize
from theano.gof.opt import Optimizer
from theano.gof import toolbox, DestroyHandler, InconsistencyError
from theano.compile import optdb from theano.compile import optdb
from theano.compile.function_module import deep_copy_op from theano.compile.function_module import deep_copy_op
...@@ -117,7 +120,7 @@ def remove_constants_and_unused_inputs_scan(node): ...@@ -117,7 +120,7 @@ def remove_constants_and_unused_inputs_scan(node):
if len(nw_inner) != len(op_ins): if len(nw_inner) != len(op_ins):
op_outs = scan_utils.clone(op_outs, replace=givens) op_outs = scan_utils.clone(op_outs, replace=givens)
nw_info = op.info.copy() nw_info = copy.deepcopy(op.info)
nw_info['n_seqs'] = nw_n_seqs nw_info['n_seqs'] = nw_n_seqs
# DEBUG CHECK # DEBUG CHECK
nwScan = scan_op.Scan(nw_inner, op_outs, nw_info) nwScan = scan_op.Scan(nw_inner, op_outs, nw_info)
...@@ -304,14 +307,35 @@ scan_seqopt.register('scanOp_pushout_nonseqs_ops', ...@@ -304,14 +307,35 @@ scan_seqopt.register('scanOp_pushout_nonseqs_ops',
'scan') 'scan')
@gof.local_optimizer([None])
def scan_make_inplace(node): class ScanInplaceOptimizer(Optimizer):
"""Graph optimizer for Scan(makes it run inplace)"""
def __init__(self, typeConstructor=None, gpu_flag=False):
Optimizer.__init__(self)
self.typeConstructor = typeConstructor
self.gpu_flag = gpu_flag
def add_requirements(self, env):
env.extend(toolbox.ReplaceValidate())
env.extend(DestroyHandler())
def apply(self, env):
nodes = env.toposort()
scan_nodes = [x for x in nodes
if (isinstance(x.op, scan_op.Scan) and
x.op.info['gpu']== self.gpu_flag)]
for scan_idx in xrange(len(scan_nodes)):
node = scan_nodes[scan_idx]
op = node.op op = node.op
if (isinstance(op, scan_op.Scan) and n_outs = (op.info['n_mit_mot'] +
(not op.info['inplace']) and op.info['n_mit_sot'] +
(not op.info['gpu'])): op.info['n_sit_sot'])
info = op.info.copy() for pos in xrange(n_outs):
info['inplace'] = True info = copy.deepcopy(op.info)
if not 'destroy_map' in info:
info['destroy_map'] = {}
info['destroy_map'][pos] = [pos + 1 + op.info['n_seqs']]
# inputs corresponding to sequences and n_steps # inputs corresponding to sequences and n_steps
ls_begin = node.inputs[:1 + op.n_seqs] ls_begin = node.inputs[:1 + op.n_seqs]
ls = op.outer_mitmot(node.inputs) ls = op.outer_mitmot(node.inputs)
...@@ -328,12 +352,23 @@ def scan_make_inplace(node): ...@@ -328,12 +352,23 @@ def scan_make_inplace(node):
inputs = ls_begin + ls + ls_end inputs = ls_begin + ls + ls_end
new_op = scan_op.Scan(op.inputs, new_op = scan_op.Scan(op.inputs,
op.outputs, op.outputs,
info) info,
return new_op.make_node(*inputs).outputs typeConstructor=self.typeConstructor)
return False
new_outs = new_op.make_node(*inputs).outputs
try:
env.replace_all_validate(
zip(node.outputs, new_outs),
reason=self.__class__.__name__)
op = new_op
node = new_outs[0].owner
except InconsistencyError, e:
# Failed moving output to be comptued inplace
pass
optdb.register('scanOp_make_inplace', optdb.register('scanOp_make_inplace',
opt.in2out(scan_make_inplace, ignore_newtrees=True), ScanInplaceOptimizer(typeConstructor=None,
gpu_flag=False),
75, 75,
'fast_run', 'fast_run',
'inplace', 'inplace',
......
...@@ -775,7 +775,10 @@ class T_Scan(unittest.TestCase): ...@@ -775,7 +775,10 @@ class T_Scan(unittest.TestCase):
updates=updates, updates=updates,
mode=mode, mode=mode,
allow_input_downcast=True) allow_input_downcast=True)
scan_node = [x for x in f9.maker.env.toposort()
if isinstance(x.op, theano.scan_module.scan_op.Scan)]
assert 0 in scan_node[0].op.destroy_map.keys()
assert 1 in scan_node[0].op.destroy_map.keys()
# compute output in numpy # compute output in numpy
numpy_x0 = numpy.zeros((3,)) numpy_x0 = numpy.zeros((3,))
numpy_x1 = numpy.zeros((3,)) numpy_x1 = numpy.zeros((3,))
...@@ -852,6 +855,10 @@ class T_Scan(unittest.TestCase): ...@@ -852,6 +855,10 @@ class T_Scan(unittest.TestCase):
mode=mode, mode=mode,
allow_input_downcast=True) allow_input_downcast=True)
scan_node = [x for x in f9.maker.env.toposort()
if isinstance(x.op, theano.scan_module.scan_op.Scan)]
assert 0 in scan_node[0].op.destroy_map.keys()
assert 1 in scan_node[0].op.destroy_map.keys()
# compute output in numpy # compute output in numpy
numpy_x0 = numpy.zeros((3,)) numpy_x0 = numpy.zeros((3,))
numpy_x1 = numpy.zeros((3,)) numpy_x1 = numpy.zeros((3,))
......
This source diff could not be displayed because it is too large. You can view the blob instead.
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论