提交 639d27b0 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

In get_constant_value, check if an array has one unique value

上级 867e6981
...@@ -418,14 +418,15 @@ def get_constant_value(v): ...@@ -418,14 +418,15 @@ def get_constant_value(v):
""" """
if isinstance(v, Constant): if isinstance(v, Constant):
#TODO: consider checking for arrays of the form e.g. [1,1,1,1] where if getattr(v.tag, 'unique_value', None) is not None:
# it is not a constant, but in some cases it *could* be replaced with one. data = v.tag.unique_value
# Note that this would have an effect on the broadcasting of inputs and so on else:
data = v.data
try: try:
numpy.complex(v.data) #works for all numeric scalars numpy.complex(data) #works for all numeric scalars
return v.data return data
except: except:
raise TypeError('v.data is non-numeric or non-scalar', v) raise TypeError('v.data is non-numeric, non-scalar, or has more than one unique value', v)
if v.owner: if v.owner:
if isinstance(v.owner.op, Alloc): if isinstance(v.owner.op, Alloc):
return get_constant_value(v.owner.inputs[0]) return get_constant_value(v.owner.inputs[0])
...@@ -5342,7 +5343,7 @@ def Rop(f, wrt, eval_points): ...@@ -5342,7 +5343,7 @@ def Rop(f, wrt, eval_points):
if len(rval) == 1: if len(rval) == 1:
return rval[0] return rval[0]
else: else:
return rval return rval
def Lop(f, wrt, eval_points, consider_constant=[], warn_type=False, def Lop(f, wrt, eval_points, consider_constant=[], warn_type=False,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论