提交 d9ab67db authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #5270 from abergeron/scan_pushout_

Add a better error for scan pushout failure
......@@ -48,14 +48,6 @@ relies on the following elements to work properly :
"""
from __future__ import absolute_import, print_function, division
__docformat__ = 'restructedtext en'
__authors__ = ("Razvan Pascanu "
"Frederic Bastien "
"James Bergstra "
"Pascal Lamblin ")
__copyright__ = "(c) 2010, Universite de Montreal"
__contact__ = "Razvan Pascanu <r.pascanu@gmail>"
import copy
import itertools
import logging
......@@ -63,11 +55,10 @@ import time
from collections import OrderedDict
import numpy
from six import iteritems, integer_types
from six import iteritems, integer_types, raise_from
from six.moves import xrange
import theano
from theano.compat import exc_message
from theano.compile import function, In, Out
from theano.compile.mode import AddFeatureOptimizer
from theano import compile, config, gradient, gof, tensor
......@@ -84,6 +75,14 @@ from theano.compile.profiling import ScanProfileStats
from theano.scan_module import scan_utils
from theano.scan_module.scan_utils import safe_new, forced_replace
__docformat__ = 'restructedtext en'
__authors__ = ("Razvan Pascanu "
"Frederic Bastien "
"James Bergstra "
"Pascal Lamblin ")
__copyright__ = "(c) 2010, Universite de Montreal"
__contact__ = "Razvan Pascanu <r.pascanu@gmail>"
# Logging function for sending warning or info
_logger = logging.getLogger('theano.scan_module.scan_op')
......@@ -140,8 +139,9 @@ class Scan(PureOp):
self.output_types = []
idx = 0
jdx = 0
tensorConstructor = lambda broadcastable, dtype: TensorType(
broadcastable=broadcastable, dtype=dtype)
def tensorConstructor(broadcastable, dtype):
return TensorType(broadcastable=broadcastable, dtype=dtype)
if typeConstructor is None:
typeConstructor = tensorConstructor
......@@ -209,7 +209,8 @@ class Scan(PureOp):
else:
tmp_in, tmp_out = scan_utils.reconstruct_graph(self.inputs,
self.outputs)
local_fgraph = gof.FunctionGraph(tmp_in, tmp_out, clone=False)
# This is actually required for the line just after.
gof.FunctionGraph(tmp_in, tmp_out, clone=False)
self._cmodule_key = gof.CLinker().cmodule_key_variables(self.inputs,
self.outputs,
[])
......@@ -447,8 +448,8 @@ class Scan(PureOp):
new_inputs.append(outer_mitmot)
for k in xrange(len(itaps)):
if (inner_mitmot[ipos + k].type.dtype !=
outer_mitmot.type.dtype or
inner_mitmot[ipos + k].ndim != outer_mitmot.ndim - 1):
outer_mitmot.type.dtype or
inner_mitmot[ipos + k].ndim != outer_mitmot.ndim - 1):
raise ValueError(err_msg1 % ('initial state (outputs_info'
' in scan nomenclature) ',
str(outer_mitmot),
......@@ -487,9 +488,9 @@ class Scan(PureOp):
new_inputs.append(outer_mitsot)
for k in xrange(len(itaps)):
if (inner_mitsots[ipos + k].type.dtype != \
outer_mitsot.type.dtype or
inner_mitsots[ipos + k].ndim != outer_mitsot.ndim - 1):
if (inner_mitsots[ipos + k].type.dtype !=
outer_mitsot.type.dtype or
inner_mitsots[ipos + k].ndim != outer_mitsot.ndim - 1):
raise ValueError(err_msg1 % ('initial state (outputs_info'
' in scan nomenclature) ',
str(outer_mitsot),
......@@ -587,9 +588,8 @@ class Scan(PureOp):
# need to store. This input does not have the same dtype, nor is it the same
# type of tensor as the output, it is always a scalar int.
new_inputs += self.outer_nitsot(inputs)
for inner_nonseq, _outer_nonseq in zip(
self.inner_non_seqs(self.inputs),
self.outer_non_seqs(inputs)):
for inner_nonseq, _outer_nonseq in zip(self.inner_non_seqs(self.inputs),
self.outer_non_seqs(inputs)):
outer_nonseq = format(_outer_nonseq, as_var=inner_nonseq)
new_inputs.append(outer_nonseq)
if inner_nonseq.type != outer_nonseq.type:
......@@ -602,7 +602,7 @@ class Scan(PureOp):
# depicts the size in memory for that sequence. This feature is
# used by truncated BPTT and by scan space optimization
if (str(outer_nitsot.type.dtype)[:3] not in ('uin', 'int') or
outer_nitsot.ndim != 0):
outer_nitsot.ndim != 0):
raise ValueError('For output %s you need to provide a '
'scalar int !', str(outer_nitsot))
assert len(new_inputs) == len(inputs)
......@@ -630,9 +630,9 @@ class Scan(PureOp):
# Check if we are dealing with same type of objects
if not type(self) == type(other):
return False
if not 'destroy_map' in self.info:
if 'destroy_map' not in self.info:
self.info['destroy_map'] = OrderedDict()
if not 'destroy_map' in other.info:
if 'destroy_map' not in other.info:
other.info['destroy_map'] = OrderedDict()
keys_to_check = ['truncate_gradient', 'profile',
'n_seqs', 'tap_array',
......@@ -675,7 +675,7 @@ class Scan(PureOp):
self.destroy_map = OrderedDict()
if len(self.destroy_map.keys()) > 0:
# Check if all outputs are inplace
if (sorted(self.destroy_map.keys()) == \
if (sorted(self.destroy_map.keys()) ==
sorted(range(self.n_mit_mot +
self.n_mit_sot +
self.n_sit_sot))):
......@@ -737,9 +737,6 @@ class Scan(PureOp):
node_input_storage = [storage_map[r] for r in node.inputs]
node_output_storage = [storage_map[r] for r in node.outputs]
node_input_compute = [compute_map[r] for r in node.inputs]
node_output_compute = [compute_map[r] for r in node.outputs]
#_logger.debug('Compiling node %i of graph' % node_idx)
# If a shared variable is the result of a ViewOp it is a clear
# indication that we need to copy that value after the perform of
# scan is done
......@@ -840,8 +837,8 @@ class Scan(PureOp):
profile = None
if (theano.config.profile or
(isinstance(self.profile, (string_types, bool, integer_types))
and self.profile)):
(isinstance(self.profile, (string_types, bool, integer_types)) and
self.profile)):
if isinstance(self.profile, string_types):
profile = ScanProfileStats(name=self.profile)
else:
......@@ -916,32 +913,33 @@ class Scan(PureOp):
cython_destroy_map = numpy.asarray(cython_destroy_map,
dtype='int32')
from . import scan_perform_ext
p = lambda node, args, outs:\
scan_perform_ext.perform(self.n_shared_outs,
self.n_mit_mot_outs,
self.n_seqs,
self.n_mit_mot,
self.n_mit_sot,
self.n_sit_sot,
self.n_nit_sot,
args[0],
self.as_while,
cython_mintaps,
cython_tap_array,
cython_tap_array_len,
cython_vector_seqs,
cython_vector_outs,
cython_mit_mot_out_slices,
cython_mit_mot_out_nslices,
cython_mitmots_preallocated,
cython_inps_is_tensor,
cython_outs_is_tensor,
self.fn.fn,
self.fn,
cython_destroy_map,
args,
outs,
self, node)
def p(node, args, outs):
return scan_perform_ext.perform(self.n_shared_outs,
self.n_mit_mot_outs,
self.n_seqs,
self.n_mit_mot,
self.n_mit_sot,
self.n_sit_sot,
self.n_nit_sot,
args[0],
self.as_while,
cython_mintaps,
cython_tap_array,
cython_tap_array_len,
cython_vector_seqs,
cython_vector_outs,
cython_mit_mot_out_slices,
cython_mit_mot_out_nslices,
cython_mitmots_preallocated,
cython_inps_is_tensor,
cython_outs_is_tensor,
self.fn.fn,
self.fn,
cython_destroy_map,
args,
outs,
self, node)
except (ImportError, theano.gof.cmodule.MissingGXX):
p = self.execute
# default arguments are stored in the closure of `rval`
......@@ -1183,8 +1181,8 @@ class Scan(PureOp):
outs[idx][0] = args[self.seqs_arg_offset + idx]
elif (outs[idx][0] is not None and
outs[idx][0].shape[1:] == args[self.seqs_arg_offset +
idx].shape[1:]
and outs[idx][0].shape[0] >= store_steps[idx]):
idx].shape[1:] and
outs[idx][0].shape[0] >= store_steps[idx]):
# Put in the values of the initial state
outs[idx][0] = outs[idx][0][:store_steps[idx]]
if idx > self.n_mit_mot:
......@@ -1212,7 +1210,7 @@ class Scan(PureOp):
i = 0
cond = True
############## THE MAIN LOOP #########################
# ############# THE MAIN LOOP ##############
# for i in xrange(n_steps):
while (i < n_steps) and cond:
# sequences over which scan iterates
......@@ -1263,7 +1261,7 @@ class Scan(PureOp):
for idx in xrange(self.n_outs + self.n_nit_sot -
self.n_mit_mot):
if (store_steps[idx + self.n_mit_mot] == 1 or
self.vector_outs[idx + self.n_mit_mot]):
self.vector_outs[idx + self.n_mit_mot]):
output_storage[idx + offset].storage[0] = None
else:
_pos0 = idx + self.n_mit_mot
......@@ -1434,8 +1432,17 @@ class Scan(PureOp):
output_reused = False
if not output_reused:
outs[j][0][pos[j]] = \
output_storage[offset_out + j].storage[0]
try:
outs[j][0][pos[j]] = \
output_storage[offset_out + j].storage[0]
except ValueError as e:
ne = ValueError(
"An output of the scan has changed shape. "
"This may be caused by a pushout optimization."
" Try adding "
"'optimizer_excluding=scanOp_pushout_output' "
"to your Theano flags.")
raise_from(ne, e)
# 5.5 Copy over the values for nit_sot outputs
begin = end
......@@ -1497,7 +1504,7 @@ class Scan(PureOp):
end = self.n_outs + self.n_nit_sot
for idx in xrange(begin, end):
if (store_steps[idx] < i - self.mintaps[idx] and
pos[idx] < store_steps[idx]):
pos[idx] < store_steps[idx]):
pdx = pos[idx]
if pdx >= store_steps[idx] // 2:
......@@ -1752,7 +1759,7 @@ class Scan(PureOp):
j_inp_idx = self.var_mappings["outer_inp_from_outer_out"][jidx]
if j_inp_idx != -1:
if connection_pattern[j_inp_idx][iidx] == True:
if connection_pattern[j_inp_idx][iidx] is True:
for k in xrange(len(connection_pattern)):
if connection_pattern[k][jidx]:
connection_pattern[k][iidx] = True
......@@ -2106,7 +2113,6 @@ class Scan(PureOp):
dc_dxts_idx += 1
dC_dinps_t = compute_all_gradients(known_grads)
# mask inputs that get no gradients
for dx in xrange(len(dC_dinps_t)):
if not dC_dinps_t[dx]:
......@@ -2159,7 +2165,6 @@ class Scan(PureOp):
outer_inp_seqs = [x[::-1] for x in inputs[1:1 + self.n_seqs]]
for idx in xrange(self.n_mit_mot + self.n_mit_sot):
mintap = numpy.min(self.tap_array[idx])
maxtap = numpy.max(self.tap_array[idx])
if idx < self.n_mit_mot:
outmaxtap = numpy.max(self.mitmot_out_taps()[idx])
else:
......@@ -2205,7 +2210,7 @@ class Scan(PureOp):
# Restrict the length of the outer sequences to the number of grad
# steps
outer_inp_seqs = [seq[:grad_steps] for seq in outer_inp_seqs]
outer_inp_seqs = [s_[:grad_steps] for s_ in outer_inp_seqs]
inner_inp_seqs = self.inner_seqs(self_inputs)
inner_inp_seqs += self.inner_mitmot(self_inputs)
......@@ -2215,7 +2220,6 @@ class Scan(PureOp):
inner_inp_seqs += Xts
# mitmot
outer_inp_mitmot = []
outer_out_mitmot = []
inner_inp_mitmot = []
inner_out_mitmot = []
mitmot_inp_taps = []
......@@ -2496,11 +2500,10 @@ class Scan(PureOp):
outer_inp_seqs +
outer_inp_mitmot +
outer_inp_sitsot +
[inputs[0] for x in xrange(n_nit_sot)] +
[inputs[0] for _ in xrange(n_nit_sot)] +
self.outer_shared(inputs) +
self.outer_non_seqs(inputs))
inner_other_args = self_inputs[offset:]
inner_gfn_ins = (inner_inp_seqs +
inner_inp_mitmot +
inner_inp_sitsot +
......@@ -2704,7 +2707,7 @@ class Scan(PureOp):
e = e + self.n_mit_sot
ib = ie
ie = ie + int(numpy.sum([len(x) for x in
self.tap_array[self.n_mit_mot: \
self.tap_array[self.n_mit_mot:
self.n_mit_mot + self.n_mit_sot]]))
clean_eval_points = []
for inp, evp in zip(inputs[b:e], eval_points[b:e]):
......@@ -2827,7 +2830,7 @@ gof.ops_with_inner_function[Scan] = 'fn'
# TODO: move that to the new back-end and new profiling.py print_tips
#@theano.compile.profilemode.register_profiler_printer
# @theano.compile.profilemode.register_profiler_printer
def profile_printer(fct_name, compile_time, fct_call_time, fct_call,
apply_time, apply_cimpl, message, outputs_size,
other_time):
......@@ -2836,9 +2839,9 @@ def profile_printer(fct_name, compile_time, fct_call_time, fct_call,
apply_time.items()]):
print()
print('Scan overhead:')
print ('<Scan op time(s)> <sub scan fct time(s)> <sub scan op '
'time(s)> <sub scan fct time(% scan op time)> <sub scan '
'op time(% scan op time)> <node>')
print('<Scan op time(s)> <sub scan fct time(s)> <sub scan op '
'time(s)> <sub scan fct time(% scan op time)> <sub scan '
'op time(% scan op time)> <node>')
total_super_scan_time = 0
total_scan_fct_time = 0
total_scan_op_time = 0
......
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -62,7 +62,7 @@ import copy
def get_version():
    """Return the version of the cython scan implementation.

    Must match the module-level ``version`` constant in
    ``scan_perform_ext.py`` so that stale compiled modules are detected
    and rebuilt.
    """
    # 0.294: bumped for the added shape-change ValueError in perform().
    return 0.294
@cython.boundscheck(False)
def perform(
......@@ -544,7 +544,15 @@ def perform(
output_reused = False
if not output_reused:
outs[j][0][pos[j]] = output_storage[j+offset_out].storage[0]
try:
outs[j][0][pos[j]] = output_storage[j+offset_out].storage[0]
except ValueError as e:
raise ValueError(
"An output of the scan has changed shape. "
"This may be caused by a pushout optimization."
" Try adding "
"'optimizer_excluding=scanOp_pushout_output' "
"to your Theano flags.")
# 5.6 Copy over the values for outputs corresponding to shared
# variables
......
......@@ -17,7 +17,7 @@ from theano.gof import cmodule
_logger = logging.getLogger('theano.scan_module.scan_perform')
version = 0.293 # must match constant returned in function get_version()
version = 0.294 # must match constant returned in function get_version()
need_reload = False
......@@ -94,7 +94,7 @@ except ImportError:
# the old interface.
if False:
# During scan cython development, it is helpful to keep the old interface, to don't manually edit the c file each time.
preargs.remove('-D NPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION')
preargs.remove('-DNPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION')
else:
numpy_ver = [int(n) for n in numpy.__version__.split('.')[:2]]
# Add add some macro to lower the number of edit
......@@ -102,13 +102,13 @@ except ImportError:
if bool(numpy_ver >= [1, 7]):
# Needed when we disable the old API, as cython
# use the old interface
preargs.append("-D NPY_ENSUREARRAY=NPY_ARRAY_ENSUREARRAY")
preargs.append("-D NPY_ENSURECOPY=NPY_ARRAY_ENSURECOPY")
preargs.append("-D NPY_ALIGNED=NPY_ARRAY_ALIGNED")
preargs.append("-D NPY_WRITEABLE=NPY_ARRAY_WRITEABLE")
preargs.append("-D NPY_UPDATE_ALL=NPY_ARRAY_UPDATE_ALL")
preargs.append("-D NPY_C_CONTIGUOUS=NPY_ARRAY_C_CONTIGUOUS")
preargs.append("-D NPY_F_CONTIGUOUS=NPY_ARRAY_F_CONTIGUOUS")
preargs.append("-DNPY_ENSUREARRAY=NPY_ARRAY_ENSUREARRAY")
preargs.append("-DNPY_ENSURECOPY=NPY_ARRAY_ENSURECOPY")
preargs.append("-DNPY_ALIGNED=NPY_ARRAY_ALIGNED")
preargs.append("-DNPY_WRITEABLE=NPY_ARRAY_WRITEABLE")
preargs.append("-DNPY_UPDATE_ALL=NPY_ARRAY_UPDATE_ALL")
preargs.append("-DNPY_C_CONTIGUOUS=NPY_ARRAY_C_CONTIGUOUS")
preargs.append("-DNPY_F_CONTIGUOUS=NPY_ARRAY_F_CONTIGUOUS")
cmodule.GCC_compiler.compile_str(dirname, code, location=loc,
preargs=preargs,
......
......@@ -213,7 +213,7 @@ class TestPushOutScanOutputDot(object):
# not be the result of a Dot
scan_node = [node for node in f_opt.maker.fgraph.toposort()
if isinstance(node.op, Scan)][0]
# NOTE: WHEN INFER_SHAPE IS REENABLED, BELLOW THE SCAN MUST
# NOTE: WHEN INFER_SHAPE IS REENABLED, BELOW THE SCAN MUST
# HAVE ONLY 1 OUTPUT.
assert len(scan_node.op.outputs) == 2
assert not isinstance(scan_node.op.outputs[0], T.Dot)
......
......@@ -88,7 +88,6 @@ whitelist_flake8 = [
"scan_module/scan_utils.py",
"scan_module/scan_views.py",
"scan_module/scan.py",
"scan_module/scan_op.py",
"scan_module/scan_perform_ext.py",
"scan_module/__init__.py",
"scan_module/tests/__init__.py",
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请先 注册 或者 登录 后发表评论