提交 697592d9 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

raise an error on empty constants instead of returning them

上级 63bfae75
...@@ -466,9 +466,14 @@ def _allclose(a, b, rtol=None, atol=None): ...@@ -466,9 +466,14 @@ def _allclose(a, b, rtol=None, atol=None):
class NotScalarConstantError(Exception): class NotScalarConstantError(Exception):
""" """
Raised by get_scalar_constant_value if called on something that is Raised by get_scalar_constant_value if called on something that is
not constant. not a scalar constant.
"""
class EmptyConstantError(NotScalarConstantError):
"""
Raised by get_scalar_const_value if called on something that is a
zero dimensional constant.
""" """
pass
def get_scalar_constant_value(v): def get_scalar_constant_value(v):
"""return the constant scalar(0-D) value underlying variable `v` """return the constant scalar(0-D) value underlying variable `v`
...@@ -488,7 +493,7 @@ def get_scalar_constant_value(v): ...@@ -488,7 +493,7 @@ def get_scalar_constant_value(v):
raise NotScalarConstantError() raise NotScalarConstantError()
if isinstance(v, (int, float)): if isinstance(v, (int, float)):
return v return numpy.asarray(v)
def numpy_scalar(n): def numpy_scalar(n):
""" Return a scalar stored in a numpy ndarray, or raise """ Return a scalar stored in a numpy ndarray, or raise
...@@ -496,10 +501,10 @@ def get_scalar_constant_value(v): ...@@ -496,10 +501,10 @@ def get_scalar_constant_value(v):
""" """
# handle case where data is numpy.array([]) # handle case where data is numpy.array([])
if hasattr(data, 'shape') and len(data.shape) == 0 or \ if data.ndim > 0 and (len(data.shape) == 0 or
__builtins__['max'](data.shape) == 0: __builtins__['max'](data.shape) == 0):
assert numpy.all(numpy.array([]) == data) assert numpy.all(numpy.array([]) == data)
return data raise EmptyConstantError()
try: try:
numpy.complex(data) # works for all numeric scalars numpy.complex(data) # works for all numeric scalars
return data return data
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论