提交 82642e99 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

better name

上级 73005085
......@@ -163,7 +163,7 @@ def dot(l, r):
return rval
def get_constant_value(v):
def get_scalar_constant_value(v):
"""return the constant scalar(0-D) value underlying variable `v`
If v is the output of dimshuffles, fills, allocs, rebroadcasts, cast
......@@ -172,12 +172,12 @@ def get_constant_value(v):
If theano.sparse is also there, we will look over CSM op.
If `v` is not some view of constant data, then raise a
tensor.basic.NotConstantError.
tensor.basic.NotScalarConstantError.
"""
if hasattr(theano, 'sparse') and isinstance(v.type,
theano.sparse.SparseType):
if v.owner is not None and isinstance(v.owner.op,
theano.sparse.CSM):
data = v.owner.inputs[0]
return tensor.get_constant_value(data)
return tensor.get_constant_value(v)
return tensor.get_scalar_constant_value(data)
return tensor.get_scalar_constant_value(v)
......@@ -975,7 +975,7 @@ def _populate_grad_dict(var_to_app_to_idx,
msg += "%s."
msg % (str(node.op), str(term), str(type(term)),
i, str(theano.get_constant_value(term)))
i, str(theano.get_scalar_constant_value(term)))
raise ValueError(msg)
......@@ -1616,9 +1616,9 @@ def _is_zero(x):
no_constant_value = True
try:
constant_value = theano.get_constant_value(x)
constant_value = theano.get_scalar_constant_value(x)
no_constant_value = False
except theano.tensor.basic.NotConstantError:
except theano.tensor.basic.NotScalarConstantError:
pass
if no_constant_value:
......
......@@ -2685,7 +2685,7 @@ class GpuAlloc(GpuOp):
raise TypeError('Shape arguments must be integers', s)
# if s is constant 1, then we're broadcastable in that dim
try:
const_shp = tensor.get_constant_value(s)
const_shp = tensor.get_scalar_constant_value(s)
except TypeError:
const_shp = None
bcast.append(numpy.all(1 == const_shp))
......
......@@ -214,8 +214,8 @@ def is_positive(v):
logger.debug('is_positive: %s' % str(v))
if v.owner and v.owner.op == tensor.pow:
try:
exponent = tensor.get_constant_value(v.owner.inputs[1])
except tensor.basic.NotConstantError:
exponent = tensor.get_scalar_constant_value(v.owner.inputs[1])
except tensor.basic.NotScalarConstantError:
return False
if 0 == exponent % 2:
return True
......
......@@ -135,8 +135,8 @@ def scan(fn,
n_fixed_steps = int(n_steps)
else:
try:
n_fixed_steps = opt.get_constant_value(n_steps)
except tensor.basic.NotConstantError:
n_fixed_steps = opt.get_scalar_constant_value(n_steps)
except tensor.basic.NotScalarConstantError:
n_fixed_steps = None
# Check n_steps is an int
......
......@@ -335,7 +335,7 @@ def scan(fn,
T_value = int(n_steps)
else:
try:
T_value = opt.get_constant_value(n_steps)
T_value = opt.get_scalar_constant_value(n_steps)
except (TypeError, AttributeError):
T_value = None
......
......@@ -24,7 +24,7 @@ from theano.compile.pfunc import rebuild_collect_shared
from theano import gof
from theano import tensor, scalar
from theano.gof.python25 import all
from theano.tensor.basic import get_constant_value
from theano.tensor.basic import get_scalar_constant_value
# Logging function for sending warning or info
......
......@@ -363,8 +363,8 @@ def scan(fn,
n_fixed_steps = int(n_steps)
else:
try:
n_fixed_steps = opt.get_constant_value(n_steps)
except tensor.basic.NotConstantError:
n_fixed_steps = opt.get_scalar_constant_value(n_steps)
except tensor.basic.NotScalarConstantError:
n_fixed_steps = None
# Check n_steps is an int
......
......@@ -18,7 +18,7 @@ import numpy
import theano
from theano import tensor
from theano.tensor import opt, get_constant_value
from theano.tensor import opt, get_scalar_constant_value
from theano import gof
from theano.gof.python25 import maxsize, any
from theano.gof.opt import Optimizer
......@@ -1164,13 +1164,13 @@ class ScanMerge(gof.Optimizer):
nsteps = node.inputs[0]
try:
nsteps = int(get_constant_value(nsteps))
nsteps = int(get_scalar_constant_value(nsteps))
except TypeError:
pass
rep_nsteps = rep.inputs[0]
try:
rep_nsteps = int(get_constant_value(rep_nsteps))
rep_nsteps = int(get_scalar_constant_value(rep_nsteps))
except TypeError:
pass
......
......@@ -463,20 +463,20 @@ def _allclose(a, b, rtol=None, atol=None):
return numpy.allclose(a, b, atol=atol_, rtol=rtol_)
class NotConstantError(Exception):
class NotScalarConstantError(Exception):
"""
Raised by get_constant_value if called on something that is
Raised by get_scalar_constant_value if called on something that is
not constant.
"""
pass
def get_constant_value(v):
def get_scalar_constant_value(v):
"""return the constant scalar(0-D) value underlying variable `v`
If v is the output of dimshuffles, fills, allocs, rebroadcasts, cast
this function digs through them.
If `v` is not some view of constant data, then raise a NotConstantError.
If `v` is not some view of constant scalar data, then raise a NotScalarConstantError.
:note: There may be another function similar to this one in the
code, but I'm not sure where it is.
......@@ -496,30 +496,31 @@ def get_constant_value(v):
numpy.complex(data) # works for all numeric scalars
return data
except Exception:
raise NotConstantError(
raise NotScalarConstantError(
'v.data is non-numeric, non-scalar, or has more than one'
' unique value', v)
if v.owner:
if isinstance(v.owner.op, Alloc):
return get_constant_value(v.owner.inputs[0])
return get_scalar_constant_value(v.owner.inputs[0])
if isinstance(v.owner.op, DimShuffle):
return get_constant_value(v.owner.inputs[0])
return get_scalar_constant_value(v.owner.inputs[0])
if isinstance(v.owner.op, Rebroadcast):
return get_constant_value(v.owner.inputs[0])
return get_scalar_constant_value(v.owner.inputs[0])
if isinstance(v.owner.op, Elemwise) and \
isinstance(v.owner.op.scalar_op, scal.Second):
shape, val = v.owner.inputs
return get_constant_value(val)
return get_scalar_constant_value(val)
if isinstance(v.owner.op, scal.Second):
x, y = v.owner.inputs
return get_constant_value(y)
return get_scalar_constant_value(y)
# Don't act as the constant_folding optimization here as this
# fct is used too early in the optimization phase. This would
# mess with the stabilization optimization.
if (isinstance(v.owner.op, Elemwise) and isinstance(
v.owner.op.scalar_op, scal.Cast)) or \
isinstance(v.owner.op, scal.Cast):
const = get_constant_value(v.owner.inputs[0])
const = get_scalar_constant_value(v.owner.inputs[0])
ret = [[None]]
v.owner.op.perform(v.owner, [const], ret)
return ret[0][0]
......@@ -556,7 +557,7 @@ def get_constant_value(v):
# axis.
ret = v.owner.inputs[0].owner.inputs[
v.owner.op.idx_list[0] + 1]
ret = get_constant_value(ret)
ret = get_scalar_constant_value(ret)
# join can cast implicitly its input in some case.
return theano._asarray(ret, dtype=v.type.dtype)
if (v.owner.inputs[0].owner and
......@@ -569,7 +570,7 @@ def get_constant_value(v):
len(v.owner.op.idx_list) == 1):
ret = v.owner.inputs[0].owner.inputs[v.owner.op.idx_list[0]]
ret = get_constant_value(ret)
ret = get_scalar_constant_value(ret)
# MakeVector can cast implicitly its input in some case.
return theano._asarray(ret, dtype=v.type.dtype)
......@@ -582,7 +583,7 @@ def get_constant_value(v):
v.owner.op.idx_list[0]]:
return numpy.asarray(1)
raise NotConstantError(v)
raise NotScalarConstantError(v)
class TensorType(Type):
......@@ -1825,8 +1826,8 @@ class _tensor_py_operators:
# TO TRUMP NUMPY OPERATORS
__array_priority__ = 1000
def get_constant_value(self):
return get_constant_value(self)
def get_scalar_constant_value(self):
return get_scalar_constant_value(self)
def zeros_like(model):
return zeros_like(model)
......@@ -2354,10 +2355,10 @@ class SpecifyShape(Op):
new_shape = []
for dim in xrange(node.inputs[0].ndim):
try:
s = get_constant_value(node.inputs[1][dim])
s = get_scalar_constant_value(node.inputs[1][dim])
s = as_tensor_variable(s)
new_shape.append(s)
except NotConstantError:
except NotScalarConstantError:
new_shape.append(node.inputs[1][dim])
assert len(new_shape) == len(xshape)
......@@ -2653,9 +2654,9 @@ def max(x, axis=None, keepdims=False):
out = CAReduce(scal.maximum, axis)(x)
else:
try:
const = get_constant_value(axis)
const = get_scalar_constant_value(axis)
out = CAReduce(scal.maximum, list(const))(x)
except NotConstantError:
except NotScalarConstantError:
out = max_and_argmax(x, axis)[0]
if keepdims:
......@@ -3264,8 +3265,8 @@ class Alloc(gof.Op):
(i, s_as_str))
# if s is constant 1, then we're broadcastable in that dim
try:
const_shp = get_constant_value(s)
except NotConstantError:
const_shp = get_scalar_constant_value(s)
except NotScalarConstantError:
const_shp = None
bcast.append(numpy.all(1 == const_shp))
otype = TensorType(dtype=v.dtype, broadcastable=bcast)
......@@ -3804,16 +3805,16 @@ def get_idx_list(inputs, idx_list):
def extract_constant(x):
'''
This function is basically a call to tensor.get_constant_value. The
This function is basically a call to tensor.get_scalar_constant_value. The
main difference is the behaviour in case of failure. While
get_constant_value raises an TypeError, this function returns x,
get_scalar_constant_value raises a TypeError, this function returns x,
as a tensor if possible. If x is a ScalarVariable from a
scalar_from_tensor, we remove the conversion. If x is just a
ScalarVariable, we convert it to a tensor with tensor_from_scalar.
'''
try:
x = get_constant_value(x)
except NotConstantError:
x = get_scalar_constant_value(x)
except NotScalarConstantError:
pass
if (isinstance(x, scal.ScalarVariable) or
isinstance(x, scal.sharedvar.ScalarSharedVariable)):
......@@ -5412,11 +5413,11 @@ class Join(Op):
# Axis can also be a constant
if not isinstance(axis, int):
try:
# Note : `get_constant_value` returns a ndarray not a
# Note : `get_scalar_constant_value` returns a ndarray not a
# int
axis = int(get_constant_value(axis))
axis = int(get_scalar_constant_value(axis))
except NotConstantError:
except NotScalarConstantError:
pass
if isinstance(axis, int):
# Basically, broadcastable -> length 1, but the
......@@ -5783,9 +5784,9 @@ class Reshape(Op):
# Try to see if we can infer that y has a constant value of 1.
# If so, that dimension should be broadcastable.
try:
bcasts[index] = (hasattr(y, 'get_constant_value') and
y.get_constant_value() == 1)
except NotConstantError:
bcasts[index] = (hasattr(y, 'get_scalar_constant_value') and
y.get_scalar_constant_value() == 1)
except NotScalarConstantError:
pass
return gof.Apply(self, [x, shp], [tensor(x.type.dtype, bcasts)])
......@@ -5858,10 +5859,10 @@ class Reshape(Op):
for i in xrange(self.ndim):
default_os_i = theano.tensor.opt.Shape_i(i)(node.outputs[0])
try:
os_i = get_constant_value(node.inputs[1][i]).item()
os_i = get_scalar_constant_value(node.inputs[1][i]).item()
if os_i == -1:
os_i = default_os_i
except NotConstantError:
except NotScalarConstantError:
os_i = default_os_i
oshape.append(os_i)
return [tuple(oshape)]
......@@ -6141,9 +6142,9 @@ class ARange(Op):
def is_constant_value(var, value):
try:
v = get_constant_value(var)
v = get_scalar_constant_value(var)
return numpy.all(v == value)
except NotConstantError:
except NotScalarConstantError:
pass
return False
......
......@@ -1614,7 +1614,7 @@ def local_gemm_to_ger(node):
xv = x.dimshuffle(0)
yv = y.dimshuffle(1)
try:
bval = T.get_constant_value(b)
bval = T.get_scalar_constant_value(b)
except TypeError:
# b isn't a constant, GEMM is doing useful pre-scaling
return
......
......@@ -15,7 +15,7 @@ import logging
import numpy
import theano
from theano.tensor import (as_tensor_variable, blas, get_constant_value,
from theano.tensor import (as_tensor_variable, blas, get_scalar_constant_value,
patternbroadcast)
from theano import OpenMPOp, config
from theano.gof import Apply
......@@ -90,7 +90,7 @@ def conv2d(input, filters, image_shape=None, filter_shape=None,
image_shape = list(image_shape)
for i in xrange(len(image_shape)):
if image_shape[i] is not None:
image_shape[i] = get_constant_value(
image_shape[i] = get_scalar_constant_value(
as_tensor_variable(image_shape[i]))
assert str(image_shape[i].dtype).startswith('int')
image_shape[i] = int(image_shape[i])
......@@ -98,7 +98,7 @@ def conv2d(input, filters, image_shape=None, filter_shape=None,
filter_shape = list(filter_shape)
for i in xrange(len(filter_shape)):
if filter_shape[i] is not None:
filter_shape[i] = get_constant_value(
filter_shape[i] = get_scalar_constant_value(
as_tensor_variable(filter_shape[i]))
assert str(filter_shape[i].dtype).startswith('int')
filter_shape[i] = int(filter_shape[i])
......
......@@ -1409,7 +1409,7 @@ def _check_rows_is_arange_len_labels(rows, labels):
def _is_const(z, val, approx=False):
try:
maybe = opt.get_constant_value(z)
maybe = opt.get_scalar_constant_value(z)
except TypeError:
return False
if approx:
......
......@@ -136,9 +136,9 @@ def _is_1(expr):
"""rtype bool. True iff expr is a constant close to 1
"""
try:
v = opt.get_constant_value(expr)
v = opt.get_scalar_constant_value(expr)
return numpy.allclose(v, 1)
except tensor.NotConstantError:
except tensor.NotScalarConstantError:
return False
log1msigm_to_softplus = gof.PatternSub(
......@@ -275,7 +275,7 @@ def is_neg(var):
if apply.op == tensor.mul and len(apply.inputs) >= 2:
for idx, mul_input in enumerate(apply.inputs):
try:
constant = opt.get_constant_value(mul_input)
constant = opt.get_scalar_constant_value(mul_input)
is_minus_1 = numpy.allclose(constant, -1)
except TypeError:
is_minus_1 = False
......@@ -647,7 +647,7 @@ def local_1msigmoid(node):
return # graph is using both sigm and 1-sigm
if sub_r.owner and sub_r.owner.op == sigmoid:
try:
val_l = opt.get_constant_value(sub_l)
val_l = opt.get_scalar_constant_value(sub_l)
except Exception, e:
return
if numpy.allclose(numpy.sum(val_l), 1):
......
......@@ -30,10 +30,10 @@ class TestConv2D(utt.InferShapeTester):
verify_grad=True, should_raise=False):
if N_image_shape is None:
N_image_shape = [T.get_constant_value(T.
N_image_shape = [T.get_scalar_constant_value(T.
as_tensor_variable(x)) for x in image_shape]
if N_filter_shape is None:
N_filter_shape = [T.get_constant_value(T.
N_filter_shape = [T.get_scalar_constant_value(T.
as_tensor_variable(x)) for x in filter_shape]
if input is None:
......
......@@ -33,7 +33,7 @@ from theano.gof.opt import (Optimizer, pre_constant_merge,
pre_greedy_local_optimizer)
from theano.gof.opt import merge_optimizer
from theano.gof import toolbox, DestroyHandler
from basic import get_constant_value, ShapeError, NotConstantError
from basic import get_scalar_constant_value, ShapeError, NotScalarConstantError
theano.configparser.AddConfigVar('on_shape_error',
......@@ -92,10 +92,10 @@ def scalarconsts_rest(inputs):
nonconsts = []
for i in inputs:
try:
v = get_constant_value(i)
v = get_scalar_constant_value(i)
consts.append(v)
origconsts.append(i)
except NotConstantError:
except NotScalarConstantError:
nonconsts.append(i)
return consts, origconsts, nonconsts
......@@ -322,15 +322,15 @@ def local_0_dot_x(node):
y = node.inputs[1]
replace = False
try:
if get_constant_value(x) == 0:
if get_scalar_constant_value(x) == 0:
replace = True
except NotConstantError:
except NotScalarConstantError:
pass
try:
if get_constant_value(y) == 0:
if get_scalar_constant_value(y) == 0:
replace = True
except NotConstantError:
except NotScalarConstantError:
pass
if replace:
......@@ -1177,9 +1177,9 @@ def local_subtensor_make_vector(node):
elif isinstance(idx, Variable):
# if it is a constant we can do something with it
try:
v = get_constant_value(idx)
v = get_scalar_constant_value(idx)
return [x.owner.inputs[v]]
except NotConstantError:
except NotScalarConstantError:
pass
else:
# it is a slice of ints and/or Variables
......@@ -1315,13 +1315,13 @@ def local_remove_useless_assert(node):
cond = []
for c in node.inputs[1:]:
try:
const = get_constant_value(c)
const = get_scalar_constant_value(c)
if 0 != const.ndim or const == 0:
#Should we raise an error here? How to be sure it
#is not catched?
cond.append(c)
except NotConstantError:
except NotScalarConstantError:
cond.append(c)
if len(cond) == 0:
......@@ -1477,7 +1477,7 @@ def local_upcast_elemwise_constant_inputs(node):
else:
try:
# works only for scalars
cval_i = get_constant_value(i)
cval_i = get_scalar_constant_value(i)
if all(i.broadcastable):
new_inputs.append(T.shape_padleft(
T.cast(cval_i, output_dtype),
......@@ -1490,7 +1490,7 @@ def local_upcast_elemwise_constant_inputs(node):
*[shape_i(d)(i) for d in xrange(i.ndim)]))
#print >> sys.stderr, "AAA",
#*[Shape_i(d)(i) for d in xrange(i.ndim)]
except TypeError:
except NotScalarConstantError:
#for the case of a non-scalar
if isinstance(i, T.TensorConstant):
new_inputs.append(T.cast(i, output_dtype))
......@@ -1550,8 +1550,8 @@ def local_useless_subtensor(node):
length_pos = shape_of[node.inputs[0]][pos]
try:
length_pos_data = get_constant_value(length_pos)
except NotConstantError:
length_pos_data = get_scalar_constant_value(length_pos)
except NotScalarConstantError:
pass
if isinstance(idx.stop, int):
......@@ -2032,9 +2032,9 @@ def local_incsubtensor_of_allocs(node):
y = node.inputs[1]
replace = False
try:
if get_constant_value(y) == 0:
if get_scalar_constant_value(y) == 0:
replace = True
except NotConstantError:
except NotScalarConstantError:
pass
if replace:
......@@ -2059,13 +2059,13 @@ def local_setsubtensor_of_allocs(node):
replace_y = None
try:
replace_x = get_constant_value(x)
except NotConstantError:
replace_x = get_scalar_constant_value(x)
except NotScalarConstantError:
pass
try:
replace_y = get_constant_value(y)
except NotConstantError:
replace_y = get_scalar_constant_value(y)
except NotScalarConstantError:
pass
if (replace_x == replace_y and
......@@ -2253,24 +2253,24 @@ def local_mul_switch_sink(node):
if i.owner and i.owner.op == T.switch:
switch = i.owner
try:
if get_constant_value(switch.inputs[1]) == 0.:
if get_scalar_constant_value(switch.inputs[1]) == 0.:
listmul = node.inputs[:idx] + node.inputs[idx + 1:]
fct = [T.switch(switch.inputs[0], 0,
T.mul(*(listmul + [switch.inputs[2]])))]
fct[0].values_eq_approx = fct[
0].type.values_eq_approx_remove_nan
return fct
except NotConstantError:
except NotScalarConstantError:
pass
try:
if get_constant_value(switch.inputs[2]) == 0.:
if get_scalar_constant_value(switch.inputs[2]) == 0.:
listmul = node.inputs[:idx] + node.inputs[idx + 1:]
fct = [T.switch(switch.inputs[0],
T.mul(*(listmul + [switch.inputs[1]])), 0)]
fct[0].values_eq_approx = fct[
0].type.values_eq_approx_remove_nan
return fct
except NotConstantError:
except NotScalarConstantError:
pass
return False
......@@ -2295,22 +2295,22 @@ def local_div_switch_sink(node):
if node.inputs[0].owner and node.inputs[0].owner.op == T.switch:
switch = node.inputs[0].owner
try:
if get_constant_value(switch.inputs[1]) == 0.:
if get_scalar_constant_value(switch.inputs[1]) == 0.:
fct = [T.switch(switch.inputs[0], 0,
op(switch.inputs[2], node.inputs[1]))]
fct[0].values_eq_approx = fct[
0].type.values_eq_approx_remove_nan
return fct
except NotConstantError:
except NotScalarConstantError:
pass
try:
if get_constant_value(switch.inputs[2]) == 0.:
if get_scalar_constant_value(switch.inputs[2]) == 0.:
fct = [T.switch(switch.inputs[0],
op(switch.inputs[1], node.inputs[1]), 0)]
fct[0].values_eq_approx = fct[
0].type.values_eq_approx_remove_nan
return fct
except NotConstantError:
except NotScalarConstantError:
pass
return False
......@@ -2375,7 +2375,7 @@ if 0:
def tmp(thing):
try:
return T.get_constant_value(thing)
return T.get_scalar_constant_value(thing)
except (TypeError, ValueError), e:
print e, thing.owner.inputs[0]
return None
......@@ -2702,8 +2702,8 @@ class Canonizer(gof.LocalOptimizer):
"""
if isinstance(v, Variable):
try:
return get_constant_value(v)
except NotConstantError:
return get_scalar_constant_value(v)
except NotScalarConstantError:
return None
else:
return v
......@@ -3204,15 +3204,15 @@ def local_sum_alloc(node):
if (node.op.axis is None or
node.op.axis == tuple(range(input.ndim))):
try:
val = get_constant_value(input)
val = get_scalar_constant_value(input)
assert val.size == 1
val = val.reshape(1)[0] * T.mul(*shapes)
return [T.cast(val, dtype=node.outputs[0].dtype)]
except NotConstantError:
except NotScalarConstantError:
pass
else:
try:
val = get_constant_value(input)
val = get_scalar_constant_value(input)
assert val.size == 1
val = val.reshape(1)[0]
to_prod = [shapes[i] for i in xrange(len(shapes))
......@@ -3222,7 +3222,7 @@ def local_sum_alloc(node):
return [T.alloc(T.cast(val, dtype=node.outputs[0].dtype),
*[shapes[i] for i in xrange(len(shapes))
if i not in node.op.axis])]
except NotConstantError:
except NotScalarConstantError:
pass
......@@ -3282,8 +3282,8 @@ def local_mul_zero(node):
for i in node.inputs:
try:
value = get_constant_value(i)
except NotConstantError:
value = get_scalar_constant_value(i)
except NotScalarConstantError:
continue
#print 'MUL by value', value, node.inputs
if N.all(value == 0):
......@@ -3520,8 +3520,8 @@ def local_add_specialize(node):
new_inputs = []
for input in node.inputs:
try:
y = get_constant_value(input)
except NotConstantError:
y = get_scalar_constant_value(input)
except NotScalarConstantError:
y = input
if numpy.all(y == 0.0):
continue
......@@ -3614,7 +3614,7 @@ def local_abs_merge(node):
if i.owner and i.owner.op == T.abs_:
inputs.append(i.owner.inputs[0])
else:
const = get_constant_value(i)
const = get_scalar_constant_value(i)
if not (const >= 0).all():
return False
inputs.append(i)
......@@ -3880,9 +3880,9 @@ def _is_1(expr):
"""rtype bool. True iff expr is a constant close to 1
"""
try:
v = get_constant_value(expr)
v = get_scalar_constant_value(expr)
return numpy.allclose(v, 1)
except NotConstantError:
except NotScalarConstantError:
return False
......@@ -3890,9 +3890,9 @@ def _is_minus1(expr):
"""rtype bool. True iff expr is a constant close to -1
"""
try:
v = get_constant_value(expr)
v = get_scalar_constant_value(expr)
return numpy.allclose(v, -1)
except NotConstantError:
except NotScalarConstantError:
return False
#1+erf(x)=>erfc(-x)
......@@ -4132,8 +4132,8 @@ def local_grad_log_erfc_neg(node):
mul_neg = T.mul(*mul_inputs)
try:
cst2 = get_constant_value(mul_neg.owner.inputs[0])
except NotConstantError:
cst2 = get_scalar_constant_value(mul_neg.owner.inputs[0])
except NotScalarConstantError:
return False
if len(mul_neg.owner.inputs) == 2:
......@@ -4159,8 +4159,8 @@ def local_grad_log_erfc_neg(node):
x = erfc_x
try:
cst = get_constant_value(erfc_x.owner.inputs[0])
except NotConstantError:
cst = get_scalar_constant_value(erfc_x.owner.inputs[0])
except NotScalarConstantError:
return False
if cst2 != -cst * 2:
return False
......
......@@ -41,7 +41,7 @@ from theano.gof.python25 import any, all
from theano.gof.opt import Optimizer
from theano.gof import InconsistencyError, toolbox
from basic import get_constant_value
from basic import get_scalar_constant_value
from theano.tensor.opt import register_uncanonicalize
from theano import scalar as scal
......@@ -64,7 +64,7 @@ class MaxAndArgmaxOptimizer(Optimizer):
if node.op == T._max_and_argmax:
if len(node.outputs[1].clients)==0:
try:
axis=get_constant_value(node.inputs[1])
axis=get_scalar_constant_value(node.inputs[1])
except (ValueError, TypeError), e:
return False
......
......@@ -34,7 +34,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as,
Reshape, row, scalar, scalars, second, smallest, stack, sub, Tensor,
tensor_copy, tensordot, tensordot_grad, TensorType, unbroadcast,
var, Join, shape, MaxAndArgmax, lscalar, zvector, exp,
get_constant_value, ivector, reshape, scalar_from_tensor, scal,
get_scalar_constant_value, ivector, reshape, scalar_from_tensor, scal,
iscalars, arange, dscalars, fvector, imatrix, numeric_grad,
opt, ComplexError, TensorDot, lvector, true_div, max, min, Split, roll,
tile, patternbroadcast, Eye, Shape, Default, Dot, PermuteRowElements,
......@@ -2140,7 +2140,7 @@ class T_max_and_argmax(unittest.TestCase):
cost = argmax(x, axis=0).sum()
value_error_raised = False
gx = grad(cost, x)
val = tensor.get_constant_value(gx)
val = tensor.get_scalar_constant_value(gx)
assert val == 0.0
def test_grad(self):
......@@ -6167,40 +6167,40 @@ def test_dimshuffle_duplicate():
assert success
class T_get_constant_value(unittest.TestCase):
def test_get_constant_value(self):
class T_get_scalar_constant_value(unittest.TestCase):
def test_get_scalar_constant_value(self):
a = tensor.stack(1, 2, 3)
assert get_constant_value(a[0]) == 1
assert get_constant_value(a[1]) == 2
assert get_constant_value(a[2]) == 3
assert get_scalar_constant_value(a[0]) == 1
assert get_scalar_constant_value(a[1]) == 2
assert get_scalar_constant_value(a[2]) == 3
b = tensor.iscalar()
a = tensor.stack(b, 2, 3)
self.assertRaises(tensor.basic.NotConstantError, get_constant_value, a[0])
assert get_constant_value(a[1]) == 2
assert get_constant_value(a[2]) == 3
self.assertRaises(tensor.basic.NotScalarConstantError, get_scalar_constant_value, a[0])
assert get_scalar_constant_value(a[1]) == 2
assert get_scalar_constant_value(a[2]) == 3
# For now get_constant_value goes through only MakeVector and Join of
# For now get_scalar_constant_value goes through only MakeVector and Join of
# scalars.
v = tensor.ivector()
a = tensor.stack(v, 2, 3)
self.assertRaises(TypeError, get_constant_value, a[0])
self.assertRaises(TypeError, get_constant_value, a[1])
self.assertRaises(TypeError, get_constant_value, a[2])
self.assertRaises(TypeError, get_scalar_constant_value, a[0])
self.assertRaises(TypeError, get_scalar_constant_value, a[1])
self.assertRaises(TypeError, get_scalar_constant_value, a[2])
# Test the case SubTensor(Shape(v)) when the dimensions
# is broadcastable.
v = tensor.row()
assert get_constant_value(v.shape[0]) == 1
assert get_scalar_constant_value(v.shape[0]) == 1
def test_subtensor_of_constant(self):
c = constant(rand(5))
for i in range(c.value.shape[0]):
assert get_constant_value(c[i]) == c.value[i]
assert get_scalar_constant_value(c[i]) == c.value[i]
c = constant(rand(5, 5))
for i in range(c.value.shape[0]):
for j in range(c.value.shape[1]):
assert get_constant_value(c[i, j]) == c.value[i, j]
assert get_scalar_constant_value(c[i, j]) == c.value[i, j]
class T_as_tensor_variable(unittest.TestCase):
......
......@@ -856,7 +856,7 @@ def test_gt_grad():
"""A user test that failed.
Something about it made Elemwise.grad return something that was
too complicated for get_constant_value to recognize as being 0, so
too complicated for get_scalar_constant_value to recognize as being 0, so
gradient.grad reported that it was not a valid gradient of an
integer.
......
......@@ -936,7 +936,7 @@ class T_fibby(unittest.TestCase):
if node.op == fibby:
x = node.inputs[0]
try:
if numpy.all(0 == get_constant_value(x)):
if numpy.all(0 == get_scalar_constant_value(x)):
return [x]
except TypeError:
pass
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论