提交 e8d45dc8 authored 作者: Frederic's avatar Frederic

Move numpy_scalar outside of the function that use it.

Indent what is inside. I'll use that in the next commit. This will make the change more visible.
上级 4a77221b
......@@ -508,6 +508,24 @@ class EmptyConstantError(NotScalarConstantError):
"""
def numpy_scalar(data):
""" Return a scalar stored in a numpy ndarray, or raise
NotScalarConstantError if the numpy ndarray is not a scalar
"""
# handle case where data is numpy.array([])
if data.ndim > 0 and (len(data.shape) == 0 or
__builtins__['max'](data.shape) == 0):
assert numpy.all(numpy.array([]) == data)
raise EmptyConstantError()
try:
numpy.complex(data) # works for all numeric scalars
return data
except Exception:
raise NotScalarConstantError(
'v.data is non-numeric, non-scalar, or has more than one'
' unique value', data)
get_scalar_constant_value_elemwises = (
scal.Cast, scal.Switch,
scal.NEQ, scal.EQ,
......@@ -526,7 +544,7 @@ def get_scalar_constant_value(v):
:note: There may be another function similar to this one in the
code, but I'm not sure where it is.
"""
if True:
if v is None:
# None is not a scalar (and many uses of this function seem to depend
# on passing it None)
......@@ -535,24 +553,6 @@ def get_scalar_constant_value(v):
if isinstance(v, (numpy.integer, int, float)):
return numpy.asarray(v)
def numpy_scalar(data):
""" Return a scalar stored in a numpy ndarray, or raise
NotScalarConstantError if the numpy ndarray is not a scalar
"""
# handle case where data is numpy.array([])
if data.ndim > 0 and (len(data.shape) == 0 or
__builtins__['max'](data.shape) == 0):
assert numpy.all(numpy.array([]) == data)
raise EmptyConstantError()
try:
numpy.complex(data) # works for all numeric scalars
return data
except Exception:
raise NotScalarConstantError(
'v.data is non-numeric, non-scalar, or has more than one'
' unique value', data)
if isinstance(v, numpy.ndarray):
return numpy_scalar(v)
......@@ -568,8 +568,8 @@ def get_scalar_constant_value(v):
compile.ops.OutputGuard,
compile.DeepCopyOp)):
return get_scalar_constant_value(v.owner.inputs[0])
elif (isinstance(v.owner.op, theano.compile.ops.Shape_i) and
isinstance(v.owner.inputs[0], Constant)):
elif isinstance(v.owner.op, theano.compile.ops.Shape_i):
if isinstance(v.owner.inputs[0], Constant):
return v.owner.inputs[0].data.shape[v.owner.op.i]
# Don't act as the constant_folding optimization here as this
# fct is used too early in the optimization phase. This would
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论