提交 4d6ec42a authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Also add the message in C.

上级 d53d21a6
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -62,7 +62,7 @@ import copy ...@@ -62,7 +62,7 @@ import copy
def get_version(): def get_version():
return 0.293 return 0.294
@cython.boundscheck(False) @cython.boundscheck(False)
def perform( def perform(
...@@ -544,7 +544,15 @@ def perform( ...@@ -544,7 +544,15 @@ def perform(
output_reused = False output_reused = False
if not output_reused: if not output_reused:
try:
outs[j][0][pos[j]] = output_storage[j+offset_out].storage[0] outs[j][0][pos[j]] = output_storage[j+offset_out].storage[0]
except ValueError as e:
raise ValueError(
"An output of the scan has changed shape. "
"This may be caused by a pushout optimization."
" Try adding "
"'optimizer_excluding=scanOp_pushout_output' "
"to your Theano flags.")
# 5.6 Copy over the values for outputs corresponding to shared # 5.6 Copy over the values for outputs corresponding to shared
# variables # variables
......
...@@ -17,7 +17,7 @@ from theano.gof import cmodule ...@@ -17,7 +17,7 @@ from theano.gof import cmodule
_logger = logging.getLogger('theano.scan_module.scan_perform') _logger = logging.getLogger('theano.scan_module.scan_perform')
version = 0.293 # must match constant returned in function get_version() version = 0.294 # must match constant returned in function get_version()
need_reload = False need_reload = False
...@@ -94,7 +94,7 @@ except ImportError: ...@@ -94,7 +94,7 @@ except ImportError:
# the old interface. # the old interface.
if False: if False:
# During scan cython development, it is helpful to keep the old interface, to don't manually edit the c file each time. # During scan cython development, it is helpful to keep the old interface, to don't manually edit the c file each time.
preargs.remove('-D NPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION') preargs.remove('-DNPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION')
else: else:
numpy_ver = [int(n) for n in numpy.__version__.split('.')[:2]] numpy_ver = [int(n) for n in numpy.__version__.split('.')[:2]]
# Add add some macro to lower the number of edit # Add add some macro to lower the number of edit
...@@ -102,13 +102,13 @@ except ImportError: ...@@ -102,13 +102,13 @@ except ImportError:
if bool(numpy_ver >= [1, 7]): if bool(numpy_ver >= [1, 7]):
# Needed when we disable the old API, as cython # Needed when we disable the old API, as cython
# use the old interface # use the old interface
preargs.append("-D NPY_ENSUREARRAY=NPY_ARRAY_ENSUREARRAY") preargs.append("-DNPY_ENSUREARRAY=NPY_ARRAY_ENSUREARRAY")
preargs.append("-D NPY_ENSURECOPY=NPY_ARRAY_ENSURECOPY") preargs.append("-DNPY_ENSURECOPY=NPY_ARRAY_ENSURECOPY")
preargs.append("-D NPY_ALIGNED=NPY_ARRAY_ALIGNED") preargs.append("-DNPY_ALIGNED=NPY_ARRAY_ALIGNED")
preargs.append("-D NPY_WRITEABLE=NPY_ARRAY_WRITEABLE") preargs.append("-DNPY_WRITEABLE=NPY_ARRAY_WRITEABLE")
preargs.append("-D NPY_UPDATE_ALL=NPY_ARRAY_UPDATE_ALL") preargs.append("-DNPY_UPDATE_ALL=NPY_ARRAY_UPDATE_ALL")
preargs.append("-D NPY_C_CONTIGUOUS=NPY_ARRAY_C_CONTIGUOUS") preargs.append("-DNPY_C_CONTIGUOUS=NPY_ARRAY_C_CONTIGUOUS")
preargs.append("-D NPY_F_CONTIGUOUS=NPY_ARRAY_F_CONTIGUOUS") preargs.append("-DNPY_F_CONTIGUOUS=NPY_ARRAY_F_CONTIGUOUS")
cmodule.GCC_compiler.compile_str(dirname, code, location=loc, cmodule.GCC_compiler.compile_str(dirname, code, location=loc,
preargs=preargs, preargs=preargs,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论