提交 e8d45dc8 authored 作者: Frederic's avatar Frederic

Move numpy_scalar outside of the function that use it.

Indent what is inside. I'll use that in the next commit. This will make the change more visible.
上级 4a77221b
...@@ -508,6 +508,24 @@ class EmptyConstantError(NotScalarConstantError): ...@@ -508,6 +508,24 @@ class EmptyConstantError(NotScalarConstantError):
""" """
def numpy_scalar(data):
""" Return a scalar stored in a numpy ndarray, or raise
NotScalarConstantError if the numpy ndarray is not a scalar
"""
# handle case where data is numpy.array([])
if data.ndim > 0 and (len(data.shape) == 0 or
__builtins__['max'](data.shape) == 0):
assert numpy.all(numpy.array([]) == data)
raise EmptyConstantError()
try:
numpy.complex(data) # works for all numeric scalars
return data
except Exception:
raise NotScalarConstantError(
'v.data is non-numeric, non-scalar, or has more than one'
' unique value', data)
get_scalar_constant_value_elemwises = ( get_scalar_constant_value_elemwises = (
scal.Cast, scal.Switch, scal.Cast, scal.Switch,
scal.NEQ, scal.EQ, scal.NEQ, scal.EQ,
...@@ -526,7 +544,7 @@ def get_scalar_constant_value(v): ...@@ -526,7 +544,7 @@ def get_scalar_constant_value(v):
:note: There may be another function similar to this one in the :note: There may be another function similar to this one in the
code, but I'm not sure where it is. code, but I'm not sure where it is.
""" """
if True:
if v is None: if v is None:
# None is not a scalar (and many uses of this function seem to depend # None is not a scalar (and many uses of this function seem to depend
# on passing it None) # on passing it None)
...@@ -535,24 +553,6 @@ def get_scalar_constant_value(v): ...@@ -535,24 +553,6 @@ def get_scalar_constant_value(v):
if isinstance(v, (numpy.integer, int, float)): if isinstance(v, (numpy.integer, int, float)):
return numpy.asarray(v) return numpy.asarray(v)
def numpy_scalar(data):
""" Return a scalar stored in a numpy ndarray, or raise
NotScalarConstantError if the numpy ndarray is not a scalar
"""
# handle case where data is numpy.array([])
if data.ndim > 0 and (len(data.shape) == 0 or
__builtins__['max'](data.shape) == 0):
assert numpy.all(numpy.array([]) == data)
raise EmptyConstantError()
try:
numpy.complex(data) # works for all numeric scalars
return data
except Exception:
raise NotScalarConstantError(
'v.data is non-numeric, non-scalar, or has more than one'
' unique value', data)
if isinstance(v, numpy.ndarray): if isinstance(v, numpy.ndarray):
return numpy_scalar(v) return numpy_scalar(v)
...@@ -568,8 +568,8 @@ def get_scalar_constant_value(v): ...@@ -568,8 +568,8 @@ def get_scalar_constant_value(v):
compile.ops.OutputGuard, compile.ops.OutputGuard,
compile.DeepCopyOp)): compile.DeepCopyOp)):
return get_scalar_constant_value(v.owner.inputs[0]) return get_scalar_constant_value(v.owner.inputs[0])
elif (isinstance(v.owner.op, theano.compile.ops.Shape_i) and elif isinstance(v.owner.op, theano.compile.ops.Shape_i):
isinstance(v.owner.inputs[0], Constant)): if isinstance(v.owner.inputs[0], Constant):
return v.owner.inputs[0].data.shape[v.owner.op.i] return v.owner.inputs[0].data.shape[v.owner.op.i]
# Don't act as the constant_folding optimization here as this # Don't act as the constant_folding optimization here as this
# fct is used too early in the optimization phase. This would # fct is used too early in the optimization phase. This would
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论