提交 63bfae75 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

fixed bug where get_scalar_constant_value did not detect numpy ndarrays

as being constant
上级 204e3640
...@@ -487,11 +487,14 @@ def get_scalar_constant_value(v): ...@@ -487,11 +487,14 @@ def get_scalar_constant_value(v):
# on passing it None) # on passing it None)
raise NotScalarConstantError() raise NotScalarConstantError()
if isinstance(v, Constant): if isinstance(v, (int, float)):
if getattr(v.tag, 'unique_value', None) is not None: return v
data = v.tag.unique_value
else: def numpy_scalar(n):
data = v.data """ Return a scalar stored in a numpy ndarray, or raise
NotScalarConstantError if the numpy ndarray is not a scalar
"""
# handle case where data is numpy.array([]) # handle case where data is numpy.array([])
if hasattr(data, 'shape') and len(data.shape) == 0 or \ if hasattr(data, 'shape') and len(data.shape) == 0 or \
__builtins__['max'](data.shape) == 0: __builtins__['max'](data.shape) == 0:
...@@ -503,7 +506,18 @@ def get_scalar_constant_value(v): ...@@ -503,7 +506,18 @@ def get_scalar_constant_value(v):
except Exception: except Exception:
raise NotScalarConstantError( raise NotScalarConstantError(
'v.data is non-numeric, non-scalar, or has more than one' 'v.data is non-numeric, non-scalar, or has more than one'
' unique value', v) ' unique value', n)
if isinstance(v, numpy.ndarray):
return numpy_scalar(v)
if isinstance(v, Constant):
if getattr(v.tag, 'unique_value', None) is not None:
data = v.tag.unique_value
else:
data = v.data
return numpy_scalar(data)
if v.owner: if v.owner:
if isinstance(v.owner.op, Alloc): if isinstance(v.owner.op, Alloc):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论