提交 d9ab67db authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #5270 from abergeron/scan_pushout_

Add a better error for scan pushout failure
...@@ -48,14 +48,6 @@ relies on the following elements to work properly : ...@@ -48,14 +48,6 @@ relies on the following elements to work properly :
""" """
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
__docformat__ = 'restructedtext en'
__authors__ = ("Razvan Pascanu "
"Frederic Bastien "
"James Bergstra "
"Pascal Lamblin ")
__copyright__ = "(c) 2010, Universite de Montreal"
__contact__ = "Razvan Pascanu <r.pascanu@gmail>"
import copy import copy
import itertools import itertools
import logging import logging
...@@ -63,11 +55,10 @@ import time ...@@ -63,11 +55,10 @@ import time
from collections import OrderedDict from collections import OrderedDict
import numpy import numpy
from six import iteritems, integer_types from six import iteritems, integer_types, raise_from
from six.moves import xrange from six.moves import xrange
import theano import theano
from theano.compat import exc_message
from theano.compile import function, In, Out from theano.compile import function, In, Out
from theano.compile.mode import AddFeatureOptimizer from theano.compile.mode import AddFeatureOptimizer
from theano import compile, config, gradient, gof, tensor from theano import compile, config, gradient, gof, tensor
...@@ -84,6 +75,14 @@ from theano.compile.profiling import ScanProfileStats ...@@ -84,6 +75,14 @@ from theano.compile.profiling import ScanProfileStats
from theano.scan_module import scan_utils from theano.scan_module import scan_utils
from theano.scan_module.scan_utils import safe_new, forced_replace from theano.scan_module.scan_utils import safe_new, forced_replace
__docformat__ = 'restructedtext en'
__authors__ = ("Razvan Pascanu "
"Frederic Bastien "
"James Bergstra "
"Pascal Lamblin ")
__copyright__ = "(c) 2010, Universite de Montreal"
__contact__ = "Razvan Pascanu <r.pascanu@gmail>"
# Logging function for sending warning or info # Logging function for sending warning or info
_logger = logging.getLogger('theano.scan_module.scan_op') _logger = logging.getLogger('theano.scan_module.scan_op')
...@@ -140,8 +139,9 @@ class Scan(PureOp): ...@@ -140,8 +139,9 @@ class Scan(PureOp):
self.output_types = [] self.output_types = []
idx = 0 idx = 0
jdx = 0 jdx = 0
tensorConstructor = lambda broadcastable, dtype: TensorType(
broadcastable=broadcastable, dtype=dtype) def tensorConstructor(broadcastable, dtype):
return TensorType(broadcastable=broadcastable, dtype=dtype)
if typeConstructor is None: if typeConstructor is None:
typeConstructor = tensorConstructor typeConstructor = tensorConstructor
...@@ -209,7 +209,8 @@ class Scan(PureOp): ...@@ -209,7 +209,8 @@ class Scan(PureOp):
else: else:
tmp_in, tmp_out = scan_utils.reconstruct_graph(self.inputs, tmp_in, tmp_out = scan_utils.reconstruct_graph(self.inputs,
self.outputs) self.outputs)
local_fgraph = gof.FunctionGraph(tmp_in, tmp_out, clone=False) # This is actually required for the line just after.
gof.FunctionGraph(tmp_in, tmp_out, clone=False)
self._cmodule_key = gof.CLinker().cmodule_key_variables(self.inputs, self._cmodule_key = gof.CLinker().cmodule_key_variables(self.inputs,
self.outputs, self.outputs,
[]) [])
...@@ -447,8 +448,8 @@ class Scan(PureOp): ...@@ -447,8 +448,8 @@ class Scan(PureOp):
new_inputs.append(outer_mitmot) new_inputs.append(outer_mitmot)
for k in xrange(len(itaps)): for k in xrange(len(itaps)):
if (inner_mitmot[ipos + k].type.dtype != if (inner_mitmot[ipos + k].type.dtype !=
outer_mitmot.type.dtype or outer_mitmot.type.dtype or
inner_mitmot[ipos + k].ndim != outer_mitmot.ndim - 1): inner_mitmot[ipos + k].ndim != outer_mitmot.ndim - 1):
raise ValueError(err_msg1 % ('initial state (outputs_info' raise ValueError(err_msg1 % ('initial state (outputs_info'
' in scan nomenclature) ', ' in scan nomenclature) ',
str(outer_mitmot), str(outer_mitmot),
...@@ -487,9 +488,9 @@ class Scan(PureOp): ...@@ -487,9 +488,9 @@ class Scan(PureOp):
new_inputs.append(outer_mitsot) new_inputs.append(outer_mitsot)
for k in xrange(len(itaps)): for k in xrange(len(itaps)):
if (inner_mitsots[ipos + k].type.dtype != \ if (inner_mitsots[ipos + k].type.dtype !=
outer_mitsot.type.dtype or outer_mitsot.type.dtype or
inner_mitsots[ipos + k].ndim != outer_mitsot.ndim - 1): inner_mitsots[ipos + k].ndim != outer_mitsot.ndim - 1):
raise ValueError(err_msg1 % ('initial state (outputs_info' raise ValueError(err_msg1 % ('initial state (outputs_info'
' in scan nomenclature) ', ' in scan nomenclature) ',
str(outer_mitsot), str(outer_mitsot),
...@@ -587,9 +588,8 @@ class Scan(PureOp): ...@@ -587,9 +588,8 @@ class Scan(PureOp):
# need to store. This input does not have the same dtype, nor is it the same # need to store. This input does not have the same dtype, nor is it the same
# type of tensor as the output, it is always a scalar int. # type of tensor as the output, it is always a scalar int.
new_inputs += self.outer_nitsot(inputs) new_inputs += self.outer_nitsot(inputs)
for inner_nonseq, _outer_nonseq in zip( for inner_nonseq, _outer_nonseq in zip(self.inner_non_seqs(self.inputs),
self.inner_non_seqs(self.inputs), self.outer_non_seqs(inputs)):
self.outer_non_seqs(inputs)):
outer_nonseq = format(_outer_nonseq, as_var=inner_nonseq) outer_nonseq = format(_outer_nonseq, as_var=inner_nonseq)
new_inputs.append(outer_nonseq) new_inputs.append(outer_nonseq)
if inner_nonseq.type != outer_nonseq.type: if inner_nonseq.type != outer_nonseq.type:
...@@ -602,7 +602,7 @@ class Scan(PureOp): ...@@ -602,7 +602,7 @@ class Scan(PureOp):
# depicts the size in memory for that sequence. This feature is # depicts the size in memory for that sequence. This feature is
# used by truncated BPTT and by scan space optimization # used by truncated BPTT and by scan space optimization
if (str(outer_nitsot.type.dtype)[:3] not in ('uin', 'int') or if (str(outer_nitsot.type.dtype)[:3] not in ('uin', 'int') or
outer_nitsot.ndim != 0): outer_nitsot.ndim != 0):
raise ValueError('For output %s you need to provide a ' raise ValueError('For output %s you need to provide a '
'scalar int !', str(outer_nitsot)) 'scalar int !', str(outer_nitsot))
assert len(new_inputs) == len(inputs) assert len(new_inputs) == len(inputs)
...@@ -630,9 +630,9 @@ class Scan(PureOp): ...@@ -630,9 +630,9 @@ class Scan(PureOp):
# Check if we are dealing with same type of objects # Check if we are dealing with same type of objects
if not type(self) == type(other): if not type(self) == type(other):
return False return False
if not 'destroy_map' in self.info: if 'destroy_map' not in self.info:
self.info['destroy_map'] = OrderedDict() self.info['destroy_map'] = OrderedDict()
if not 'destroy_map' in other.info: if 'destroy_map' not in other.info:
other.info['destroy_map'] = OrderedDict() other.info['destroy_map'] = OrderedDict()
keys_to_check = ['truncate_gradient', 'profile', keys_to_check = ['truncate_gradient', 'profile',
'n_seqs', 'tap_array', 'n_seqs', 'tap_array',
...@@ -675,7 +675,7 @@ class Scan(PureOp): ...@@ -675,7 +675,7 @@ class Scan(PureOp):
self.destroy_map = OrderedDict() self.destroy_map = OrderedDict()
if len(self.destroy_map.keys()) > 0: if len(self.destroy_map.keys()) > 0:
# Check if all outputs are inplace # Check if all outputs are inplace
if (sorted(self.destroy_map.keys()) == \ if (sorted(self.destroy_map.keys()) ==
sorted(range(self.n_mit_mot + sorted(range(self.n_mit_mot +
self.n_mit_sot + self.n_mit_sot +
self.n_sit_sot))): self.n_sit_sot))):
...@@ -737,9 +737,6 @@ class Scan(PureOp): ...@@ -737,9 +737,6 @@ class Scan(PureOp):
node_input_storage = [storage_map[r] for r in node.inputs] node_input_storage = [storage_map[r] for r in node.inputs]
node_output_storage = [storage_map[r] for r in node.outputs] node_output_storage = [storage_map[r] for r in node.outputs]
node_input_compute = [compute_map[r] for r in node.inputs]
node_output_compute = [compute_map[r] for r in node.outputs]
#_logger.debug('Compiling node %i of graph' % node_idx)
# If a shared variable is the result of a ViewOp it is a clear # If a shared variable is the result of a ViewOp it is a clear
# indication that we need to copy that value after the perform of # indication that we need to copy that value after the perform of
# scan is done # scan is done
...@@ -840,8 +837,8 @@ class Scan(PureOp): ...@@ -840,8 +837,8 @@ class Scan(PureOp):
profile = None profile = None
if (theano.config.profile or if (theano.config.profile or
(isinstance(self.profile, (string_types, bool, integer_types)) (isinstance(self.profile, (string_types, bool, integer_types)) and
and self.profile)): self.profile)):
if isinstance(self.profile, string_types): if isinstance(self.profile, string_types):
profile = ScanProfileStats(name=self.profile) profile = ScanProfileStats(name=self.profile)
else: else:
...@@ -916,32 +913,33 @@ class Scan(PureOp): ...@@ -916,32 +913,33 @@ class Scan(PureOp):
cython_destroy_map = numpy.asarray(cython_destroy_map, cython_destroy_map = numpy.asarray(cython_destroy_map,
dtype='int32') dtype='int32')
from . import scan_perform_ext from . import scan_perform_ext
p = lambda node, args, outs:\
scan_perform_ext.perform(self.n_shared_outs, def p(node, args, outs):
self.n_mit_mot_outs, return scan_perform_ext.perform(self.n_shared_outs,
self.n_seqs, self.n_mit_mot_outs,
self.n_mit_mot, self.n_seqs,
self.n_mit_sot, self.n_mit_mot,
self.n_sit_sot, self.n_mit_sot,
self.n_nit_sot, self.n_sit_sot,
args[0], self.n_nit_sot,
self.as_while, args[0],
cython_mintaps, self.as_while,
cython_tap_array, cython_mintaps,
cython_tap_array_len, cython_tap_array,
cython_vector_seqs, cython_tap_array_len,
cython_vector_outs, cython_vector_seqs,
cython_mit_mot_out_slices, cython_vector_outs,
cython_mit_mot_out_nslices, cython_mit_mot_out_slices,
cython_mitmots_preallocated, cython_mit_mot_out_nslices,
cython_inps_is_tensor, cython_mitmots_preallocated,
cython_outs_is_tensor, cython_inps_is_tensor,
self.fn.fn, cython_outs_is_tensor,
self.fn, self.fn.fn,
cython_destroy_map, self.fn,
args, cython_destroy_map,
outs, args,
self, node) outs,
self, node)
except (ImportError, theano.gof.cmodule.MissingGXX): except (ImportError, theano.gof.cmodule.MissingGXX):
p = self.execute p = self.execute
# default arguments are stored in the closure of `rval` # default arguments are stored in the closure of `rval`
...@@ -1183,8 +1181,8 @@ class Scan(PureOp): ...@@ -1183,8 +1181,8 @@ class Scan(PureOp):
outs[idx][0] = args[self.seqs_arg_offset + idx] outs[idx][0] = args[self.seqs_arg_offset + idx]
elif (outs[idx][0] is not None and elif (outs[idx][0] is not None and
outs[idx][0].shape[1:] == args[self.seqs_arg_offset + outs[idx][0].shape[1:] == args[self.seqs_arg_offset +
idx].shape[1:] idx].shape[1:] and
and outs[idx][0].shape[0] >= store_steps[idx]): outs[idx][0].shape[0] >= store_steps[idx]):
# Put in the values of the initial state # Put in the values of the initial state
outs[idx][0] = outs[idx][0][:store_steps[idx]] outs[idx][0] = outs[idx][0][:store_steps[idx]]
if idx > self.n_mit_mot: if idx > self.n_mit_mot:
...@@ -1212,7 +1210,7 @@ class Scan(PureOp): ...@@ -1212,7 +1210,7 @@ class Scan(PureOp):
i = 0 i = 0
cond = True cond = True
############## THE MAIN LOOP ######################### # ############# THE MAIN LOOP ##############
# for i in xrange(n_steps): # for i in xrange(n_steps):
while (i < n_steps) and cond: while (i < n_steps) and cond:
# sequences over which scan iterates # sequences over which scan iterates
...@@ -1263,7 +1261,7 @@ class Scan(PureOp): ...@@ -1263,7 +1261,7 @@ class Scan(PureOp):
for idx in xrange(self.n_outs + self.n_nit_sot - for idx in xrange(self.n_outs + self.n_nit_sot -
self.n_mit_mot): self.n_mit_mot):
if (store_steps[idx + self.n_mit_mot] == 1 or if (store_steps[idx + self.n_mit_mot] == 1 or
self.vector_outs[idx + self.n_mit_mot]): self.vector_outs[idx + self.n_mit_mot]):
output_storage[idx + offset].storage[0] = None output_storage[idx + offset].storage[0] = None
else: else:
_pos0 = idx + self.n_mit_mot _pos0 = idx + self.n_mit_mot
...@@ -1434,8 +1432,17 @@ class Scan(PureOp): ...@@ -1434,8 +1432,17 @@ class Scan(PureOp):
output_reused = False output_reused = False
if not output_reused: if not output_reused:
outs[j][0][pos[j]] = \ try:
output_storage[offset_out + j].storage[0] outs[j][0][pos[j]] = \
output_storage[offset_out + j].storage[0]
except ValueError as e:
ne = ValueError(
"An output of the scan has changed shape. "
"This may be caused by a pushout optimization."
" Try adding "
"'optimizer_excluding=scanOp_pushout_output' "
"to your Theano flags.")
raise_from(ne, e)
# 5.5 Copy over the values for nit_sot outputs # 5.5 Copy over the values for nit_sot outputs
begin = end begin = end
...@@ -1497,7 +1504,7 @@ class Scan(PureOp): ...@@ -1497,7 +1504,7 @@ class Scan(PureOp):
end = self.n_outs + self.n_nit_sot end = self.n_outs + self.n_nit_sot
for idx in xrange(begin, end): for idx in xrange(begin, end):
if (store_steps[idx] < i - self.mintaps[idx] and if (store_steps[idx] < i - self.mintaps[idx] and
pos[idx] < store_steps[idx]): pos[idx] < store_steps[idx]):
pdx = pos[idx] pdx = pos[idx]
if pdx >= store_steps[idx] // 2: if pdx >= store_steps[idx] // 2:
...@@ -1752,7 +1759,7 @@ class Scan(PureOp): ...@@ -1752,7 +1759,7 @@ class Scan(PureOp):
j_inp_idx = self.var_mappings["outer_inp_from_outer_out"][jidx] j_inp_idx = self.var_mappings["outer_inp_from_outer_out"][jidx]
if j_inp_idx != -1: if j_inp_idx != -1:
if connection_pattern[j_inp_idx][iidx] == True: if connection_pattern[j_inp_idx][iidx] is True:
for k in xrange(len(connection_pattern)): for k in xrange(len(connection_pattern)):
if connection_pattern[k][jidx]: if connection_pattern[k][jidx]:
connection_pattern[k][iidx] = True connection_pattern[k][iidx] = True
...@@ -2106,7 +2113,6 @@ class Scan(PureOp): ...@@ -2106,7 +2113,6 @@ class Scan(PureOp):
dc_dxts_idx += 1 dc_dxts_idx += 1
dC_dinps_t = compute_all_gradients(known_grads) dC_dinps_t = compute_all_gradients(known_grads)
# mask inputs that get no gradients # mask inputs that get no gradients
for dx in xrange(len(dC_dinps_t)): for dx in xrange(len(dC_dinps_t)):
if not dC_dinps_t[dx]: if not dC_dinps_t[dx]:
...@@ -2159,7 +2165,6 @@ class Scan(PureOp): ...@@ -2159,7 +2165,6 @@ class Scan(PureOp):
outer_inp_seqs = [x[::-1] for x in inputs[1:1 + self.n_seqs]] outer_inp_seqs = [x[::-1] for x in inputs[1:1 + self.n_seqs]]
for idx in xrange(self.n_mit_mot + self.n_mit_sot): for idx in xrange(self.n_mit_mot + self.n_mit_sot):
mintap = numpy.min(self.tap_array[idx]) mintap = numpy.min(self.tap_array[idx])
maxtap = numpy.max(self.tap_array[idx])
if idx < self.n_mit_mot: if idx < self.n_mit_mot:
outmaxtap = numpy.max(self.mitmot_out_taps()[idx]) outmaxtap = numpy.max(self.mitmot_out_taps()[idx])
else: else:
...@@ -2205,7 +2210,7 @@ class Scan(PureOp): ...@@ -2205,7 +2210,7 @@ class Scan(PureOp):
# Restrict the length of the outer sequences to the number of grad # Restrict the length of the outer sequences to the number of grad
# steps # steps
outer_inp_seqs = [seq[:grad_steps] for seq in outer_inp_seqs] outer_inp_seqs = [s_[:grad_steps] for s_ in outer_inp_seqs]
inner_inp_seqs = self.inner_seqs(self_inputs) inner_inp_seqs = self.inner_seqs(self_inputs)
inner_inp_seqs += self.inner_mitmot(self_inputs) inner_inp_seqs += self.inner_mitmot(self_inputs)
...@@ -2215,7 +2220,6 @@ class Scan(PureOp): ...@@ -2215,7 +2220,6 @@ class Scan(PureOp):
inner_inp_seqs += Xts inner_inp_seqs += Xts
# mitmot # mitmot
outer_inp_mitmot = [] outer_inp_mitmot = []
outer_out_mitmot = []
inner_inp_mitmot = [] inner_inp_mitmot = []
inner_out_mitmot = [] inner_out_mitmot = []
mitmot_inp_taps = [] mitmot_inp_taps = []
...@@ -2496,11 +2500,10 @@ class Scan(PureOp): ...@@ -2496,11 +2500,10 @@ class Scan(PureOp):
outer_inp_seqs + outer_inp_seqs +
outer_inp_mitmot + outer_inp_mitmot +
outer_inp_sitsot + outer_inp_sitsot +
[inputs[0] for x in xrange(n_nit_sot)] + [inputs[0] for _ in xrange(n_nit_sot)] +
self.outer_shared(inputs) + self.outer_shared(inputs) +
self.outer_non_seqs(inputs)) self.outer_non_seqs(inputs))
inner_other_args = self_inputs[offset:]
inner_gfn_ins = (inner_inp_seqs + inner_gfn_ins = (inner_inp_seqs +
inner_inp_mitmot + inner_inp_mitmot +
inner_inp_sitsot + inner_inp_sitsot +
...@@ -2704,7 +2707,7 @@ class Scan(PureOp): ...@@ -2704,7 +2707,7 @@ class Scan(PureOp):
e = e + self.n_mit_sot e = e + self.n_mit_sot
ib = ie ib = ie
ie = ie + int(numpy.sum([len(x) for x in ie = ie + int(numpy.sum([len(x) for x in
self.tap_array[self.n_mit_mot: \ self.tap_array[self.n_mit_mot:
self.n_mit_mot + self.n_mit_sot]])) self.n_mit_mot + self.n_mit_sot]]))
clean_eval_points = [] clean_eval_points = []
for inp, evp in zip(inputs[b:e], eval_points[b:e]): for inp, evp in zip(inputs[b:e], eval_points[b:e]):
...@@ -2827,7 +2830,7 @@ gof.ops_with_inner_function[Scan] = 'fn' ...@@ -2827,7 +2830,7 @@ gof.ops_with_inner_function[Scan] = 'fn'
# TODO: move that to the new back-end and new profiling.py print_tips # TODO: move that to the new back-end and new profiling.py print_tips
#@theano.compile.profilemode.register_profiler_printer # @theano.compile.profilemode.register_profiler_printer
def profile_printer(fct_name, compile_time, fct_call_time, fct_call, def profile_printer(fct_name, compile_time, fct_call_time, fct_call,
apply_time, apply_cimpl, message, outputs_size, apply_time, apply_cimpl, message, outputs_size,
other_time): other_time):
...@@ -2836,9 +2839,9 @@ def profile_printer(fct_name, compile_time, fct_call_time, fct_call, ...@@ -2836,9 +2839,9 @@ def profile_printer(fct_name, compile_time, fct_call_time, fct_call,
apply_time.items()]): apply_time.items()]):
print() print()
print('Scan overhead:') print('Scan overhead:')
print ('<Scan op time(s)> <sub scan fct time(s)> <sub scan op ' print('<Scan op time(s)> <sub scan fct time(s)> <sub scan op '
'time(s)> <sub scan fct time(% scan op time)> <sub scan ' 'time(s)> <sub scan fct time(% scan op time)> <sub scan '
'op time(% scan op time)> <node>') 'op time(% scan op time)> <node>')
total_super_scan_time = 0 total_super_scan_time = 0
total_scan_fct_time = 0 total_scan_fct_time = 0
total_scan_op_time = 0 total_scan_op_time = 0
......
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -62,7 +62,7 @@ import copy ...@@ -62,7 +62,7 @@ import copy
def get_version(): def get_version():
return 0.293 return 0.294
@cython.boundscheck(False) @cython.boundscheck(False)
def perform( def perform(
...@@ -544,7 +544,15 @@ def perform( ...@@ -544,7 +544,15 @@ def perform(
output_reused = False output_reused = False
if not output_reused: if not output_reused:
outs[j][0][pos[j]] = output_storage[j+offset_out].storage[0] try:
outs[j][0][pos[j]] = output_storage[j+offset_out].storage[0]
except ValueError as e:
raise ValueError(
"An output of the scan has changed shape. "
"This may be caused by a pushout optimization."
" Try adding "
"'optimizer_excluding=scanOp_pushout_output' "
"to your Theano flags.")
# 5.6 Copy over the values for outputs corresponding to shared # 5.6 Copy over the values for outputs corresponding to shared
# variables # variables
......
...@@ -17,7 +17,7 @@ from theano.gof import cmodule ...@@ -17,7 +17,7 @@ from theano.gof import cmodule
_logger = logging.getLogger('theano.scan_module.scan_perform') _logger = logging.getLogger('theano.scan_module.scan_perform')
version = 0.293 # must match constant returned in function get_version() version = 0.294 # must match constant returned in function get_version()
need_reload = False need_reload = False
...@@ -94,7 +94,7 @@ except ImportError: ...@@ -94,7 +94,7 @@ except ImportError:
# the old interface. # the old interface.
if False: if False:
# During scan cython development, it is helpful to keep the old interface, to don't manually edit the c file each time. # During scan cython development, it is helpful to keep the old interface, to don't manually edit the c file each time.
preargs.remove('-D NPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION') preargs.remove('-DNPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION')
else: else:
numpy_ver = [int(n) for n in numpy.__version__.split('.')[:2]] numpy_ver = [int(n) for n in numpy.__version__.split('.')[:2]]
# Add add some macro to lower the number of edit # Add add some macro to lower the number of edit
...@@ -102,13 +102,13 @@ except ImportError: ...@@ -102,13 +102,13 @@ except ImportError:
if bool(numpy_ver >= [1, 7]): if bool(numpy_ver >= [1, 7]):
# Needed when we disable the old API, as cython # Needed when we disable the old API, as cython
# use the old interface # use the old interface
preargs.append("-D NPY_ENSUREARRAY=NPY_ARRAY_ENSUREARRAY") preargs.append("-DNPY_ENSUREARRAY=NPY_ARRAY_ENSUREARRAY")
preargs.append("-D NPY_ENSURECOPY=NPY_ARRAY_ENSURECOPY") preargs.append("-DNPY_ENSURECOPY=NPY_ARRAY_ENSURECOPY")
preargs.append("-D NPY_ALIGNED=NPY_ARRAY_ALIGNED") preargs.append("-DNPY_ALIGNED=NPY_ARRAY_ALIGNED")
preargs.append("-D NPY_WRITEABLE=NPY_ARRAY_WRITEABLE") preargs.append("-DNPY_WRITEABLE=NPY_ARRAY_WRITEABLE")
preargs.append("-D NPY_UPDATE_ALL=NPY_ARRAY_UPDATE_ALL") preargs.append("-DNPY_UPDATE_ALL=NPY_ARRAY_UPDATE_ALL")
preargs.append("-D NPY_C_CONTIGUOUS=NPY_ARRAY_C_CONTIGUOUS") preargs.append("-DNPY_C_CONTIGUOUS=NPY_ARRAY_C_CONTIGUOUS")
preargs.append("-D NPY_F_CONTIGUOUS=NPY_ARRAY_F_CONTIGUOUS") preargs.append("-DNPY_F_CONTIGUOUS=NPY_ARRAY_F_CONTIGUOUS")
cmodule.GCC_compiler.compile_str(dirname, code, location=loc, cmodule.GCC_compiler.compile_str(dirname, code, location=loc,
preargs=preargs, preargs=preargs,
......
...@@ -213,7 +213,7 @@ class TestPushOutScanOutputDot(object): ...@@ -213,7 +213,7 @@ class TestPushOutScanOutputDot(object):
# not be the result of a Dot # not be the result of a Dot
scan_node = [node for node in f_opt.maker.fgraph.toposort() scan_node = [node for node in f_opt.maker.fgraph.toposort()
if isinstance(node.op, Scan)][0] if isinstance(node.op, Scan)][0]
# NOTE: WHEN INFER_SHAPE IS REENABLED, BELLOW THE SCAN MUST # NOTE: WHEN INFER_SHAPE IS REENABLED, BELOW THE SCAN MUST
# HAVE ONLY 1 OUTPUT. # HAVE ONLY 1 OUTPUT.
assert len(scan_node.op.outputs) == 2 assert len(scan_node.op.outputs) == 2
assert not isinstance(scan_node.op.outputs[0], T.Dot) assert not isinstance(scan_node.op.outputs[0], T.Dot)
......
...@@ -88,7 +88,6 @@ whitelist_flake8 = [ ...@@ -88,7 +88,6 @@ whitelist_flake8 = [
"scan_module/scan_utils.py", "scan_module/scan_utils.py",
"scan_module/scan_views.py", "scan_module/scan_views.py",
"scan_module/scan.py", "scan_module/scan.py",
"scan_module/scan_op.py",
"scan_module/scan_perform_ext.py", "scan_module/scan_perform_ext.py",
"scan_module/__init__.py", "scan_module/__init__.py",
"scan_module/tests/__init__.py", "scan_module/tests/__init__.py",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论