提交 c3b5c808 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

added NotConstantError so get_constant_value can raise TypeError if

given bad input
上级 6a023bd4
...@@ -462,13 +462,27 @@ def _allclose(a, b, rtol=None, atol=None): ...@@ -462,13 +462,27 @@ def _allclose(a, b, rtol=None, atol=None):
return numpy.allclose(a, b, atol=atol_, rtol=rtol_) return numpy.allclose(a, b, atol=atol_, rtol=rtol_)
class NotConstantError(TypeError):
"""
Raised by get_constant_value if called on something that is
not constant.
For now it is a TypeError, to maintain the old interface
that get_constant_value should raise a TypeError in this
situation. However, this is unsafe because get_constant_value
could inadvertently raise a TypeError if it has a bug.
So we should eventually make NotConstantError derive
from Exception directly, and modify all code that uses
get_constant_value to catch this more specific exception.
"""
pass
def get_constant_value(v): def get_constant_value(v):
"""return the constant scalar(0-D) value underlying variable `v` """return the constant scalar(0-D) value underlying variable `v`
If v is the output of dimshuffles, fills, allocs, rebroadcasts, cast If v is the output of dimshuffles, fills, allocs, rebroadcasts, cast
this function digs through them. this function digs through them.
If `v` is not some view of constant data, then raise a TypeError. If `v` is not some view of constant data, then raise a NotConstantError.
:note: There may be another function similar to this one in the :note: There may be another function similar to this one in the
code, but I'm not sure where it is. code, but I'm not sure where it is.
...@@ -488,7 +502,7 @@ def get_constant_value(v): ...@@ -488,7 +502,7 @@ def get_constant_value(v):
numpy.complex(data) # works for all numeric scalars numpy.complex(data) # works for all numeric scalars
return data return data
except Exception: except Exception:
raise TypeError( raise NotConstantError(
'v.data is non-numeric, non-scalar, or has more than one' 'v.data is non-numeric, non-scalar, or has more than one'
' unique value', v) ' unique value', v)
if v.owner: if v.owner:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论