提交 ba9ffef6 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Support for get_constant to return ndarrays not only scalars.

上级 600c40e7
...@@ -409,9 +409,12 @@ def _allclose(a, b): ...@@ -409,9 +409,12 @@ def _allclose(a, b):
return numpy.allclose(a,b, atol=atol, rtol=rtol) return numpy.allclose(a,b, atol=atol, rtol=rtol)
def get_constant_value(v): def get_constant_value(v, return_ndarray = False):
"""return the constant scalar(0-D) value underlying variable `v` """return the constant scalar(0-D) value underlying variable `v`
If ``return_ndarray`` is True, it also returns numpy ndarrays if this is
the content of the constant.
If v is the output of dimshuffles, fills, allocs, rebroadcasts, cast If v is the output of dimshuffles, fills, allocs, rebroadcasts, cast
this function digs through them. this function digs through them.
...@@ -426,10 +429,12 @@ def get_constant_value(v): ...@@ -426,10 +429,12 @@ def get_constant_value(v):
# it is not a constant, but in some cases it *could* be replaced with one. # it is not a constant, but in some cases it *could* be replaced with one.
# Note that this would have an effect on the broadcasting of inputs and so on # Note that this would have an effect on the broadcasting of inputs and so on
try: try:
if return_ndarray and isinstance(v.data, numpy.ndarray):
return v.data
numpy.complex(v.data) #works for all numeric scalars numpy.complex(v.data) #works for all numeric scalars
return v.data return v.data
except: except:
raise TypeError('v.data is non-numeric', v) raise TypeError('v.data is non-numeric or non-scalar', v)
if v.owner: if v.owner:
if isinstance(v.owner.op, Alloc): if isinstance(v.owner.op, Alloc):
return get_constant_value(v.owner.inputs[0]) return get_constant_value(v.owner.inputs[0])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论