提交 fd2ad1c4 authored 作者: Frederic Bastien's avatar Frederic Bastien

Add theano flag on_shape_error.

上级 8186a4f0
......@@ -164,6 +164,7 @@ Others:
* Filtering update. (James)
* The buidbot now raises optimization errors instead of just printing a warning. (Frederic)
* On Windows, the default compiledir changed to be local to the computer/user and not transferred with roaming profile. (Sebastian Urban)
* New theano flag "on_shape_error". Default to "warn" (same as previous behavior): it print a warning when an error occur when infering the shape of some apply node. The other accepted value is "raise" to raise an error when this happen.
Reviewers (alphabetical order):
* David, Frederic, Ian, James, Olivier, Razvan
......@@ -243,6 +243,16 @@ import theano and print the config variable, as in:
the user and skip this optimization ('warn'), or raise the exception
('raise').
.. attribute:: on_shape_error
String value: 'warn' or 'raise'
Default: 'warn'
When an exception is raised when infering the shape of some apply
node, either warn the user and use a default value ('warn'), or
raise the exception ('raise').
.. attribute:: config.warn.ignore_bug_before
String value: 'None', 'all', '0.3', '0.4', '0.4.1', '0.5'
......
......@@ -34,6 +34,12 @@ from theano.gof import toolbox, DestroyHandler
from basic import get_constant_value, ShapeError
theano.configparser.AddConfigVar('on_shape_error',
"warn: print a warning and use the default"
" value. raise: raise an error",
theano.configparser.EnumStr("warn", "raise"),
in_c_key=False)
# Utilities
......@@ -906,11 +912,15 @@ class ShapeFeature(object):
'supported, and one should now use tensor.ShapeError '
'instead. The original exception message is: %s' % e)
except Exception, e:
_logger.error(('Failed to infer_shape from Op %s.\nInput shapes:'
'%s\nException encountered during infer_shape: '
'%s\nException message: %s\nTraceback: %s') %
(node.op, [self.shape_of[r] for r in node.inputs],
type(e), str(e), traceback.format_exc()))
msg = ('Failed to infer_shape from Op %s.\nInput shapes:'
'%s\nException encountered during infer_shape: '
'%s\nException message: %s\nTraceback: %s') % (
node.op, [self.shape_of[r] for r in node.inputs],
type(e), str(e), traceback.format_exc())
if config.on_shape_error == "raise":
raise Exception(msg)
else:
_logger.error(msg)
o_shapes = self.default_infer_shape(
node, [self.shape_of[r] for r in node.inputs])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论