提交 697592d9 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

raise an error on empty constants instead of returning them

上级 63bfae75
......@@ -466,9 +466,14 @@ def _allclose(a, b, rtol=None, atol=None):
class NotScalarConstantError(Exception):
"""
Raised by get_scalar_constant_value if called on something that is
not constant.
not a scalar constant.
"""
class EmptyConstantError(NotScalarConstantError):
"""
Raised by get_scalar_const_value if called on something that is a
zero dimensional constant.
"""
pass
def get_scalar_constant_value(v):
"""return the constant scalar(0-D) value underlying variable `v`
......@@ -488,7 +493,7 @@ def get_scalar_constant_value(v):
raise NotScalarConstantError()
if isinstance(v, (int, float)):
return v
return numpy.asarray(v)
def numpy_scalar(n):
""" Return a scalar stored in a numpy ndarray, or raise
......@@ -496,10 +501,10 @@ def get_scalar_constant_value(v):
"""
# handle case where data is numpy.array([])
if hasattr(data, 'shape') and len(data.shape) == 0 or \
__builtins__['max'](data.shape) == 0:
if data.ndim > 0 and (len(data.shape) == 0 or
__builtins__['max'](data.shape) == 0):
assert numpy.all(numpy.array([]) == data)
return data
raise EmptyConstantError()
try:
numpy.complex(data) # works for all numeric scalars
return data
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论