提交 483cb9d3 authored 作者: lamblin's avatar lamblin

Merge pull request #1161 from goodfeli/rebase

Ready to merge: get rid of dangerous "TypeError = not constant" mechanism
......@@ -163,7 +163,7 @@ def dot(l, r):
return rval
def get_constant_value(v):
def get_scalar_constant_value(v):
"""return the constant scalar (0-D) value underlying variable `v`
If v is the output of dimshuffles, fills, allocs, rebroadcasts, cast
......@@ -171,12 +171,13 @@ def get_constant_value(v):
If theano.sparse is also there, we will look over CSM op.
If `v` is not some view of constant data, then raise a TypeError.
If `v` is not some view of constant data, then raise a
tensor.basic.NotScalarConstantError.
"""
if hasattr(theano, 'sparse') and isinstance(v.type,
theano.sparse.SparseType):
if v.owner is not None and isinstance(v.owner.op,
theano.sparse.CSM):
data = v.owner.inputs[0]
return tensor.get_constant_value(data)
return tensor.get_constant_value(v)
return tensor.get_scalar_constant_value(data)
return tensor.get_scalar_constant_value(v)
......@@ -975,7 +975,7 @@ def _populate_grad_dict(var_to_app_to_idx,
msg += "%s."
msg % (str(node.op), str(term), str(type(term)),
i, str(theano.get_constant_value(term)))
i, str(theano.get_scalar_constant_value(term)))
raise ValueError(msg)
......@@ -1616,9 +1616,9 @@ def _is_zero(x):
no_constant_value = True
try:
constant_value = theano.get_constant_value(x)
constant_value = theano.get_scalar_constant_value(x)
no_constant_value = False
except TypeError:
except theano.tensor.basic.NotScalarConstantError:
pass
if no_constant_value:
......
......@@ -2691,8 +2691,8 @@ class GpuAlloc(GpuOp):
raise TypeError('Shape arguments must be integers', s)
# if s is constant 1, then we're broadcastable in that dim
try:
const_shp = tensor.get_constant_value(s)
except TypeError:
const_shp = tensor.get_scalar_constant_value(s)
except tensor.NotScalarConstantError:
const_shp = None
bcast.append(numpy.all(1 == const_shp))
otype = CudaNdarrayType(dtype='float32', broadcastable=bcast)
......
......@@ -214,8 +214,8 @@ def is_positive(v):
logger.debug('is_positive: %s' % str(v))
if v.owner and v.owner.op == tensor.pow:
try:
exponent = tensor.get_constant_value(v.owner.inputs[1])
except TypeError:
exponent = tensor.get_scalar_constant_value(v.owner.inputs[1])
except tensor.basic.NotScalarConstantError:
return False
if 0 == exponent % 2:
return True
......
......@@ -135,8 +135,8 @@ def scan(fn,
n_fixed_steps = int(n_steps)
else:
try:
n_fixed_steps = opt.get_constant_value(n_steps)
except (TypeError, AttributeError):
n_fixed_steps = opt.get_scalar_constant_value(n_steps)
except tensor.basic.NotScalarConstantError:
n_fixed_steps = None
# Check n_steps is an int
......
......@@ -335,7 +335,7 @@ def scan(fn,
T_value = int(n_steps)
else:
try:
T_value = opt.get_constant_value(n_steps)
T_value = opt.get_scalar_constant_value(n_steps)
except (TypeError, AttributeError):
T_value = None
......
......@@ -24,7 +24,7 @@ from theano.compile.pfunc import rebuild_collect_shared
from theano import gof
from theano import tensor, scalar
from theano.gof.python25 import all
from theano.tensor.basic import get_constant_value
from theano.tensor.basic import get_scalar_constant_value
# Logging function for sending warning or info
......
......@@ -363,8 +363,8 @@ def scan(fn,
n_fixed_steps = int(n_steps)
else:
try:
n_fixed_steps = opt.get_constant_value(n_steps)
except (TypeError, AttributeError):
n_fixed_steps = opt.get_scalar_constant_value(n_steps)
except tensor.basic.NotScalarConstantError:
n_fixed_steps = None
# Check n_steps is an int
......
......@@ -18,7 +18,7 @@ import numpy
import theano
from theano import tensor
from theano.tensor import opt, get_constant_value
from theano.tensor import opt, get_scalar_constant_value
from theano import gof
from theano.gof.python25 import maxsize, any
from theano.gof.opt import Optimizer
......@@ -1164,14 +1164,14 @@ class ScanMerge(gof.Optimizer):
nsteps = node.inputs[0]
try:
nsteps = int(get_constant_value(nsteps))
except TypeError:
nsteps = int(get_scalar_constant_value(nsteps))
except tensor.NotScalarConstantError:
pass
rep_nsteps = rep.inputs[0]
try:
rep_nsteps = int(get_constant_value(rep_nsteps))
except TypeError:
rep_nsteps = int(get_scalar_constant_value(rep_nsteps))
except tensor.NotScalarConstantError:
pass
# Check to see if it is an input of a different node
......
......@@ -25,7 +25,7 @@ from theano.compile.pfunc import rebuild_collect_shared
from theano import gof
from theano import tensor, scalar
from theano.gof.python25 import all, OrderedDict
from theano.tensor.basic import get_constant_value
from theano.tensor.basic import get_scalar_constant_value
################ Utility Functions and Classes #######################
......@@ -308,7 +308,7 @@ def isNaN_or_Inf_or_None(x):
isStr = False
if not isNaN and not isInf:
try:
val = get_constant_value(x)
val = get_scalar_constant_value(x)
isInf = numpy.isinf(val)
isNaN = numpy.isnan(val)
except Exception:
......
......@@ -463,70 +463,88 @@ def _allclose(a, b, rtol=None, atol=None):
return numpy.allclose(a, b, atol=atol_, rtol=rtol_)
class NotConstantError(TypeError):
class NotScalarConstantError(Exception):
"""
Raised by get_constant_value if called on something that is
not constant.
For now it is a TypeError, to maintain the old interface
that get_constant_value should raise a TypeError in this
situation. However, this is unsafe because get_constant_value
could inadvertently raise a TypeError if it has a bug.
So we should eventually make NotConstantError derive
from Exception directly, and modify all code that uses
get_constant_value to catch this more specific exception.
Raised by get_scalar_constant_value if called on something that is
not a scalar constant.
"""
class EmptyConstantError(NotScalarConstantError):
"""
Raised by get_scalar_constant_value if called on a constant with
zero elements (an empty array, e.g. numpy.array([])).
"""
pass
def get_constant_value(v):
def get_scalar_constant_value(v):
"""return the constant scalar (0-D) value underlying variable `v`
If v is the output of dimshuffles, fills, allocs, rebroadcasts, cast
this function digs through them.
If `v` is not some view of constant data, then raise a NotConstantError.
If `v` is not some view of constant scalar data, then raise a NotScalarConstantError.
:note: There may be another function similar to this one in the
code, but I'm not sure where it is.
"""
if isinstance(v, Constant):
if getattr(v.tag, 'unique_value', None) is not None:
data = v.tag.unique_value
else:
data = v.data
if v is None:
# None is not a scalar (and many uses of this function seem to depend
# on passing it None)
raise NotScalarConstantError()
if isinstance(v, (int, float)):
return numpy.asarray(v)
def numpy_scalar(n):
""" Return a scalar stored in a numpy ndarray, or raise
NotScalarConstantError if the numpy ndarray is not a scalar
"""
# handle case where data is numpy.array([])
if hasattr(data, 'shape') and len(data.shape) == 0 or \
__builtins__['max'](data.shape) == 0:
if data.ndim > 0 and (len(data.shape) == 0 or
__builtins__['max'](data.shape) == 0):
assert numpy.all(numpy.array([]) == data)
return data
raise EmptyConstantError()
try:
numpy.complex(data) # works for all numeric scalars
return data
except Exception:
raise NotConstantError(
raise NotScalarConstantError(
'v.data is non-numeric, non-scalar, or has more than one'
' unique value', v)
' unique value', n)
if isinstance(v, numpy.ndarray):
return numpy_scalar(v)
if isinstance(v, Constant):
if getattr(v.tag, 'unique_value', None) is not None:
data = v.tag.unique_value
else:
data = v.data
return numpy_scalar(data)
if v.owner:
if isinstance(v.owner.op, Alloc):
return get_constant_value(v.owner.inputs[0])
return get_scalar_constant_value(v.owner.inputs[0])
if isinstance(v.owner.op, DimShuffle):
return get_constant_value(v.owner.inputs[0])
return get_scalar_constant_value(v.owner.inputs[0])
if isinstance(v.owner.op, Rebroadcast):
return get_constant_value(v.owner.inputs[0])
return get_scalar_constant_value(v.owner.inputs[0])
if isinstance(v.owner.op, Elemwise) and \
isinstance(v.owner.op.scalar_op, scal.Second):
shape, val = v.owner.inputs
return get_constant_value(val)
return get_scalar_constant_value(val)
if isinstance(v.owner.op, scal.Second):
x, y = v.owner.inputs
return get_constant_value(y)
return get_scalar_constant_value(y)
# Don't act as the constant_folding optimization here as this
# fct is used too early in the optimization phase. This would
# mess with the stabilization optimization.
if (isinstance(v.owner.op, Elemwise) and isinstance(
v.owner.op.scalar_op, scal.Cast)) or \
isinstance(v.owner.op, scal.Cast):
const = get_constant_value(v.owner.inputs[0])
const = get_scalar_constant_value(v.owner.inputs[0])
ret = [[None]]
v.owner.op.perform(v.owner, [const], ret)
return ret[0][0]
......@@ -563,7 +581,7 @@ def get_constant_value(v):
# axis.
ret = v.owner.inputs[0].owner.inputs[
v.owner.op.idx_list[0] + 1]
ret = get_constant_value(ret)
ret = get_scalar_constant_value(ret)
# join can cast implicitly its input in some case.
return theano._asarray(ret, dtype=v.type.dtype)
if (v.owner.inputs[0].owner and
......@@ -576,7 +594,7 @@ def get_constant_value(v):
len(v.owner.op.idx_list) == 1):
ret = v.owner.inputs[0].owner.inputs[v.owner.op.idx_list[0]]
ret = get_constant_value(ret)
ret = get_scalar_constant_value(ret)
# MakeVector can cast implicitly its input in some case.
return theano._asarray(ret, dtype=v.type.dtype)
......@@ -589,7 +607,7 @@ def get_constant_value(v):
v.owner.op.idx_list[0]]:
return numpy.asarray(1)
raise NotConstantError(v)
raise NotScalarConstantError(v)
class TensorType(Type):
......@@ -1832,8 +1850,8 @@ class _tensor_py_operators:
# TO TRUMP NUMPY OPERATORS
__array_priority__ = 1000
def get_constant_value(self):
return get_constant_value(self)
def get_scalar_constant_value(self):
return get_scalar_constant_value(self)
def zeros_like(model):
return zeros_like(model)
......@@ -2361,10 +2379,10 @@ class SpecifyShape(Op):
new_shape = []
for dim in xrange(node.inputs[0].ndim):
try:
s = get_constant_value(node.inputs[1][dim])
s = get_scalar_constant_value(node.inputs[1][dim])
s = as_tensor_variable(s)
new_shape.append(s)
except TypeError:
except NotScalarConstantError:
new_shape.append(node.inputs[1][dim])
assert len(new_shape) == len(xshape)
......@@ -2656,14 +2674,22 @@ def max(x, axis=None, keepdims=False):
:note: we return an error as numpy when we reduce a dim with a shape of 0
"""
if isinstance(axis, (list, tuple)) and len(axis) > 1:
# We have a choice of implementing this call with the
# CAReduce op or the MaxAndArgmax op.
# MaxAndArgmax supports grad and Rop, so we prefer to use that.
# CAReduce is faster, but optimizations will replace MaxAndArgmax[0]
# with CAReduce at compile time, so at this stage the important
# thing is supporting all user interface features, not speed.
# Some cases can be implemented only with CAReduce.
# We thus prefer to use MaxAndArgmax, if possible. It does not
# support all axis arguments, so we may need to fall back to CAReduce.
try:
out = max_and_argmax(x, axis)[0]
except Exception:
out = CAReduce(scal.maximum, axis)(x)
else:
try:
const = get_constant_value(axis)
out = CAReduce(scal.maximum, list(const))(x)
except Exception:
out = max_and_argmax(x, axis)[0]
if keepdims:
out = makeKeepDims(x, out, axis)
......@@ -3271,8 +3297,8 @@ class Alloc(gof.Op):
(i, s_as_str))
# if s is constant 1, then we're broadcastable in that dim
try:
const_shp = get_constant_value(s)
except TypeError:
const_shp = get_scalar_constant_value(s)
except NotScalarConstantError:
const_shp = None
bcast.append(numpy.all(1 == const_shp))
otype = TensorType(dtype=v.dtype, broadcastable=bcast)
......@@ -3811,16 +3837,16 @@ def get_idx_list(inputs, idx_list):
def extract_constant(x):
'''
This function is basically a call to tensor.get_constant_value. The
This function is basically a call to tensor.get_scalar_constant_value. The
main difference is the behaviour in case of failure. While
get_constant_value raises an TypeError, this function returns x,
get_scalar_constant_value raises a NotScalarConstantError, this function returns x,
as a tensor if possible. If x is a ScalarVariable from a
scalar_from_tensor, we remove the conversion. If x is just a
ScalarVariable, we convert it to a tensor with tensor_from_scalar.
'''
try:
x = get_constant_value(x)
except Exception:
x = get_scalar_constant_value(x)
except NotScalarConstantError:
pass
if (isinstance(x, scal.ScalarVariable) or
isinstance(x, scal.sharedvar.ScalarSharedVariable)):
......@@ -5419,11 +5445,11 @@ class Join(Op):
# Axis can also be a constant
if not isinstance(axis, int):
try:
# Note : `get_constant_value` returns a ndarray not a
# Note : `get_scalar_constant_value` returns a ndarray not a
# int
axis = int(get_constant_value(axis))
axis = int(get_scalar_constant_value(axis))
except TypeError:
except NotScalarConstantError:
pass
if isinstance(axis, int):
# Basically, broadcastable -> length 1, but the
......@@ -5790,9 +5816,9 @@ class Reshape(Op):
# Try to see if we can infer that y has a constant value of 1.
# If so, that dimension should be broadcastable.
try:
bcasts[index] = (hasattr(y, 'get_constant_value') and
y.get_constant_value() == 1)
except TypeError:
bcasts[index] = (hasattr(y, 'get_scalar_constant_value') and
y.get_scalar_constant_value() == 1)
except NotScalarConstantError:
pass
return gof.Apply(self, [x, shp], [tensor(x.type.dtype, bcasts)])
......@@ -5865,10 +5891,10 @@ class Reshape(Op):
for i in xrange(self.ndim):
default_os_i = theano.tensor.opt.Shape_i(i)(node.outputs[0])
try:
os_i = get_constant_value(node.inputs[1][i]).item()
os_i = get_scalar_constant_value(node.inputs[1][i]).item()
if os_i == -1:
os_i = default_os_i
except TypeError:
except NotScalarConstantError:
os_i = default_os_i
oshape.append(os_i)
return [tuple(oshape)]
......@@ -6148,9 +6174,9 @@ class ARange(Op):
def is_constant_value(var, value):
try:
v = get_constant_value(var)
v = get_scalar_constant_value(var)
return numpy.all(v == value)
except Exception:
except NotScalarConstantError:
pass
return False
......
......@@ -1614,7 +1614,7 @@ def local_gemm_to_ger(node):
xv = x.dimshuffle(0)
yv = y.dimshuffle(1)
try:
bval = T.get_constant_value(b)
bval = T.get_scalar_constant_value(b)
except TypeError:
# b isn't a constant, GEMM is doing useful pre-scaling
return
......
......@@ -262,15 +262,15 @@ class RepeatOp(theano.Op):
broadcastable=[False]
else:
try:
const_reps = basic.get_constant_value(repeats)
except basic.NotConstantError:
const_reps = basic.get_scalar_constant_value(repeats)
except basic.NotScalarConstantError:
const_reps = None
if const_reps == 1:
broadcastable = x.broadcastable
else:
broadcastable = list(x.broadcastable)
broadcastable[self.axis] = False
out_type = theano.tensor.TensorType(x.dtype, broadcastable)
return theano.Apply(self, [x, repeats], [out_type()])
......
......@@ -15,7 +15,7 @@ import logging
import numpy
import theano
from theano.tensor import (as_tensor_variable, blas, get_constant_value,
from theano.tensor import (as_tensor_variable, blas, get_scalar_constant_value,
patternbroadcast)
from theano import OpenMPOp, config
from theano.gof import Apply
......@@ -90,7 +90,7 @@ def conv2d(input, filters, image_shape=None, filter_shape=None,
image_shape = list(image_shape)
for i in xrange(len(image_shape)):
if image_shape[i] is not None:
image_shape[i] = get_constant_value(
image_shape[i] = get_scalar_constant_value(
as_tensor_variable(image_shape[i]))
assert str(image_shape[i].dtype).startswith('int')
image_shape[i] = int(image_shape[i])
......@@ -98,7 +98,7 @@ def conv2d(input, filters, image_shape=None, filter_shape=None,
filter_shape = list(filter_shape)
for i in xrange(len(filter_shape)):
if filter_shape[i] is not None:
filter_shape[i] = get_constant_value(
filter_shape[i] = get_scalar_constant_value(
as_tensor_variable(filter_shape[i]))
assert str(filter_shape[i].dtype).startswith('int')
filter_shape[i] = int(filter_shape[i])
......
......@@ -1409,8 +1409,8 @@ def _check_rows_is_arange_len_labels(rows, labels):
def _is_const(z, val, approx=False):
try:
maybe = opt.get_constant_value(z)
except TypeError:
maybe = opt.get_scalar_constant_value(z)
except tensor.NotScalarConstantError:
return False
if approx:
return numpy.allclose(maybe, val)
......
......@@ -14,7 +14,7 @@ from theano.compile import optdb
from theano.configparser import AddConfigVar, BoolParam
from theano.printing import pprint, debugprint
from theano.tensor import basic as tensor
from theano.tensor import elemwise, opt
from theano.tensor import elemwise, opt, NotScalarConstantError
############
......@@ -136,9 +136,9 @@ def _is_1(expr):
"""rtype bool. True iff expr is a constant close to 1
"""
try:
v = opt.get_constant_value(expr)
v = opt.get_scalar_constant_value(expr)
return numpy.allclose(v, 1)
except TypeError:
except tensor.NotScalarConstantError:
return False
log1msigm_to_softplus = gof.PatternSub(
......@@ -275,9 +275,9 @@ def is_neg(var):
if apply.op == tensor.mul and len(apply.inputs) >= 2:
for idx, mul_input in enumerate(apply.inputs):
try:
constant = opt.get_constant_value(mul_input)
constant = opt.get_scalar_constant_value(mul_input)
is_minus_1 = numpy.allclose(constant, -1)
except TypeError:
except NotScalarConstantError:
is_minus_1 = False
if is_minus_1:
# Found a multiplication by -1.
......@@ -647,7 +647,7 @@ def local_1msigmoid(node):
return # graph is using both sigm and 1-sigm
if sub_r.owner and sub_r.owner.op == sigmoid:
try:
val_l = opt.get_constant_value(sub_l)
val_l = opt.get_scalar_constant_value(sub_l)
except Exception, e:
return
if numpy.allclose(numpy.sum(val_l), 1):
......
......@@ -30,10 +30,10 @@ class TestConv2D(utt.InferShapeTester):
verify_grad=True, should_raise=False):
if N_image_shape is None:
N_image_shape = [T.get_constant_value(T.
N_image_shape = [T.get_scalar_constant_value(T.
as_tensor_variable(x)) for x in image_shape]
if N_filter_shape is None:
N_filter_shape = [T.get_constant_value(T.
N_filter_shape = [T.get_scalar_constant_value(T.
as_tensor_variable(x)) for x in filter_shape]
if input is None:
......
......@@ -33,7 +33,7 @@ from theano.gof.opt import (Optimizer, pre_constant_merge,
pre_greedy_local_optimizer)
from theano.gof.opt import merge_optimizer
from theano.gof import toolbox, DestroyHandler
from basic import get_constant_value, ShapeError
from basic import get_scalar_constant_value, ShapeError, NotScalarConstantError
theano.configparser.AddConfigVar('on_shape_error',
......@@ -92,10 +92,10 @@ def scalarconsts_rest(inputs):
nonconsts = []
for i in inputs:
try:
v = get_constant_value(i)
v = get_scalar_constant_value(i)
consts.append(v)
origconsts.append(i)
except Exception:
except NotScalarConstantError:
nonconsts.append(i)
return consts, origconsts, nonconsts
......@@ -125,7 +125,13 @@ def broadcast_like(value, template, fgraph, dtype=None):
if rval.broadcastable[i]
and not template.broadcastable[i]])
assert rval.type.dtype == dtype
assert rval.type.broadcastable == template.broadcastable
if rval.type.broadcastable != template.broadcastable:
raise AssertionError("rval.type.broadcastable is " +
str(rval.type.broadcastable) +
" but template.broadcastable is" +
str(template.broadcastable))
return rval
......@@ -322,15 +328,15 @@ def local_0_dot_x(node):
y = node.inputs[1]
replace = False
try:
if get_constant_value(x) == 0:
if get_scalar_constant_value(x) == 0:
replace = True
except TypeError:
except NotScalarConstantError:
pass
try:
if get_constant_value(y) == 0:
if get_scalar_constant_value(y) == 0:
replace = True
except TypeError:
except NotScalarConstantError:
pass
if replace:
......@@ -1177,9 +1183,9 @@ def local_subtensor_make_vector(node):
elif isinstance(idx, Variable):
# if it is a constant we can do something with it
try:
v = get_constant_value(idx)
v = get_scalar_constant_value(idx)
return [x.owner.inputs[v]]
except Exception:
except NotScalarConstantError:
pass
else:
# it is a slice of ints and/or Variables
......@@ -1315,13 +1321,13 @@ def local_remove_useless_assert(node):
cond = []
for c in node.inputs[1:]:
try:
const = get_constant_value(c)
const = get_scalar_constant_value(c)
if 0 != const.ndim or const == 0:
#Should we raise an error here? How to be sure it
#is not catched?
cond.append(c)
except TypeError:
except NotScalarConstantError:
cond.append(c)
if len(cond) == 0:
......@@ -1477,7 +1483,7 @@ def local_upcast_elemwise_constant_inputs(node):
else:
try:
# works only for scalars
cval_i = get_constant_value(i)
cval_i = get_scalar_constant_value(i)
if all(i.broadcastable):
new_inputs.append(T.shape_padleft(
T.cast(cval_i, output_dtype),
......@@ -1490,7 +1496,7 @@ def local_upcast_elemwise_constant_inputs(node):
*[shape_i(d)(i) for d in xrange(i.ndim)]))
#print >> sys.stderr, "AAA",
#*[Shape_i(d)(i) for d in xrange(i.ndim)]
except TypeError:
except NotScalarConstantError:
#for the case of a non-scalar
if isinstance(i, T.TensorConstant):
new_inputs.append(T.cast(i, output_dtype))
......@@ -1550,8 +1556,8 @@ def local_useless_subtensor(node):
length_pos = shape_of[node.inputs[0]][pos]
try:
length_pos_data = get_constant_value(length_pos)
except TypeError:
length_pos_data = get_scalar_constant_value(length_pos)
except NotScalarConstantError:
pass
if isinstance(idx.stop, int):
......@@ -2032,9 +2038,9 @@ def local_incsubtensor_of_allocs(node):
y = node.inputs[1]
replace = False
try:
if get_constant_value(y) == 0:
if get_scalar_constant_value(y) == 0:
replace = True
except TypeError:
except NotScalarConstantError:
pass
if replace:
......@@ -2059,13 +2065,13 @@ def local_setsubtensor_of_allocs(node):
replace_y = None
try:
replace_x = get_constant_value(x)
except TypeError:
replace_x = get_scalar_constant_value(x)
except NotScalarConstantError:
pass
try:
replace_y = get_constant_value(y)
except TypeError:
replace_y = get_scalar_constant_value(y)
except NotScalarConstantError:
pass
if (replace_x == replace_y and
......@@ -2253,24 +2259,24 @@ def local_mul_switch_sink(node):
if i.owner and i.owner.op == T.switch:
switch = i.owner
try:
if get_constant_value(switch.inputs[1]) == 0.:
if get_scalar_constant_value(switch.inputs[1]) == 0.:
listmul = node.inputs[:idx] + node.inputs[idx + 1:]
fct = [T.switch(switch.inputs[0], 0,
T.mul(*(listmul + [switch.inputs[2]])))]
fct[0].values_eq_approx = fct[
0].type.values_eq_approx_remove_nan
return fct
except TypeError:
except NotScalarConstantError:
pass
try:
if get_constant_value(switch.inputs[2]) == 0.:
if get_scalar_constant_value(switch.inputs[2]) == 0.:
listmul = node.inputs[:idx] + node.inputs[idx + 1:]
fct = [T.switch(switch.inputs[0],
T.mul(*(listmul + [switch.inputs[1]])), 0)]
fct[0].values_eq_approx = fct[
0].type.values_eq_approx_remove_nan
return fct
except TypeError:
except NotScalarConstantError:
pass
return False
......@@ -2295,22 +2301,22 @@ def local_div_switch_sink(node):
if node.inputs[0].owner and node.inputs[0].owner.op == T.switch:
switch = node.inputs[0].owner
try:
if get_constant_value(switch.inputs[1]) == 0.:
if get_scalar_constant_value(switch.inputs[1]) == 0.:
fct = [T.switch(switch.inputs[0], 0,
op(switch.inputs[2], node.inputs[1]))]
fct[0].values_eq_approx = fct[
0].type.values_eq_approx_remove_nan
return fct
except TypeError:
except NotScalarConstantError:
pass
try:
if get_constant_value(switch.inputs[2]) == 0.:
if get_scalar_constant_value(switch.inputs[2]) == 0.:
fct = [T.switch(switch.inputs[0],
op(switch.inputs[1], node.inputs[1]), 0)]
fct[0].values_eq_approx = fct[
0].type.values_eq_approx_remove_nan
return fct
except TypeError:
except NotScalarConstantError:
pass
return False
......@@ -2375,7 +2381,7 @@ if 0:
def tmp(thing):
try:
return T.get_constant_value(thing)
return T.get_scalar_constant_value(thing)
except (TypeError, ValueError), e:
print e, thing.owner.inputs[0]
return None
......@@ -2702,8 +2708,8 @@ class Canonizer(gof.LocalOptimizer):
"""
if isinstance(v, Variable):
try:
return get_constant_value(v)
except TypeError:
return get_scalar_constant_value(v)
except NotScalarConstantError:
return None
else:
return v
......@@ -3204,15 +3210,15 @@ def local_sum_alloc(node):
if (node.op.axis is None or
node.op.axis == tuple(range(input.ndim))):
try:
val = get_constant_value(input)
val = get_scalar_constant_value(input)
assert val.size == 1
val = val.reshape(1)[0] * T.mul(*shapes)
return [T.cast(val, dtype=node.outputs[0].dtype)]
except TypeError:
except NotScalarConstantError:
pass
else:
try:
val = get_constant_value(input)
val = get_scalar_constant_value(input)
assert val.size == 1
val = val.reshape(1)[0]
to_prod = [shapes[i] for i in xrange(len(shapes))
......@@ -3222,7 +3228,7 @@ def local_sum_alloc(node):
return [T.alloc(T.cast(val, dtype=node.outputs[0].dtype),
*[shapes[i] for i in xrange(len(shapes))
if i not in node.op.axis])]
except TypeError:
except NotScalarConstantError:
pass
......@@ -3282,8 +3288,8 @@ def local_mul_zero(node):
for i in node.inputs:
try:
value = get_constant_value(i)
except TypeError:
value = get_scalar_constant_value(i)
except NotScalarConstantError:
continue
#print 'MUL by value', value, node.inputs
if N.all(value == 0):
......@@ -3520,8 +3526,8 @@ def local_add_specialize(node):
new_inputs = []
for input in node.inputs:
try:
y = get_constant_value(input)
except TypeError:
y = get_scalar_constant_value(input)
except NotScalarConstantError:
y = input
if numpy.all(y == 0.0):
continue
......@@ -3614,7 +3620,7 @@ def local_abs_merge(node):
if i.owner and i.owner.op == T.abs_:
inputs.append(i.owner.inputs[0])
else:
const = get_constant_value(i)
const = get_scalar_constant_value(i)
if not (const >= 0).all():
return False
inputs.append(i)
......@@ -3880,9 +3886,9 @@ def _is_1(expr):
"""rtype bool. True iff expr is a constant close to 1
"""
try:
v = get_constant_value(expr)
v = get_scalar_constant_value(expr)
return numpy.allclose(v, 1)
except TypeError:
except NotScalarConstantError:
return False
......@@ -3890,9 +3896,9 @@ def _is_minus1(expr):
"""rtype bool. True iff expr is a constant close to -1
"""
try:
v = get_constant_value(expr)
v = get_scalar_constant_value(expr)
return numpy.allclose(v, -1)
except TypeError:
except NotScalarConstantError:
return False
#1+erf(x)=>erfc(-x)
......@@ -4132,8 +4138,8 @@ def local_grad_log_erfc_neg(node):
mul_neg = T.mul(*mul_inputs)
try:
cst2 = get_constant_value(mul_neg.owner.inputs[0])
except TypeError:
cst2 = get_scalar_constant_value(mul_neg.owner.inputs[0])
except NotScalarConstantError:
return False
if len(mul_neg.owner.inputs) == 2:
......@@ -4159,8 +4165,8 @@ def local_grad_log_erfc_neg(node):
x = erfc_x
try:
cst = get_constant_value(erfc_x.owner.inputs[0])
except TypeError:
cst = get_scalar_constant_value(erfc_x.owner.inputs[0])
except NotScalarConstantError:
return False
if cst2 != -cst * 2:
return False
......
......@@ -41,7 +41,7 @@ from theano.gof.python25 import any, all
from theano.gof.opt import Optimizer
from theano.gof import InconsistencyError, toolbox
from basic import get_constant_value
from basic import get_scalar_constant_value, NotScalarConstantError
from theano.tensor.opt import register_uncanonicalize
from theano import scalar as scal
......@@ -64,8 +64,8 @@ class MaxAndArgmaxOptimizer(Optimizer):
if node.op == T._max_and_argmax:
if len(node.outputs[1].clients)==0:
try:
axis=get_constant_value(node.inputs[1])
except (ValueError, TypeError), e:
axis=get_scalar_constant_value(node.inputs[1])
except NotScalarConstantError:
return False
new = CAReduce(scal.maximum,axis)(node.inputs[0])
......
......@@ -34,7 +34,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as,
Reshape, row, scalar, scalars, second, smallest, stack, sub, Tensor,
tensor_copy, tensordot, tensordot_grad, TensorType, unbroadcast,
var, Join, shape, MaxAndArgmax, lscalar, zvector, exp,
get_constant_value, ivector, reshape, scalar_from_tensor, scal,
get_scalar_constant_value, ivector, reshape, scalar_from_tensor, scal,
iscalars, arange, dscalars, fvector, imatrix, numeric_grad,
opt, ComplexError, TensorDot, lvector, true_div, max, min, Split, roll,
tile, patternbroadcast, Eye, Shape, Default, Dot, PermuteRowElements,
......@@ -2140,7 +2140,7 @@ class T_max_and_argmax(unittest.TestCase):
cost = argmax(x, axis=0).sum()
value_error_raised = False
gx = grad(cost, x)
val = tensor.get_constant_value(gx)
val = tensor.get_scalar_constant_value(gx)
assert val == 0.0
def test_grad(self):
......@@ -6167,40 +6167,40 @@ def test_dimshuffle_duplicate():
assert success
class T_get_constant_value(unittest.TestCase):
def test_get_constant_value(self):
class T_get_scalar_constant_value(unittest.TestCase):
def test_get_scalar_constant_value(self):
a = tensor.stack(1, 2, 3)
assert get_constant_value(a[0]) == 1
assert get_constant_value(a[1]) == 2
assert get_constant_value(a[2]) == 3
assert get_scalar_constant_value(a[0]) == 1
assert get_scalar_constant_value(a[1]) == 2
assert get_scalar_constant_value(a[2]) == 3
b = tensor.iscalar()
a = tensor.stack(b, 2, 3)
self.assertRaises(TypeError, get_constant_value, a[0])
assert get_constant_value(a[1]) == 2
assert get_constant_value(a[2]) == 3
self.assertRaises(tensor.basic.NotScalarConstantError, get_scalar_constant_value, a[0])
assert get_scalar_constant_value(a[1]) == 2
assert get_scalar_constant_value(a[2]) == 3
# For now get_constant_value goes through only MakeVector and Join of
# For now get_scalar_constant_value goes through only MakeVector and Join of
# scalars.
v = tensor.ivector()
a = tensor.stack(v, 2, 3)
self.assertRaises(TypeError, get_constant_value, a[0])
self.assertRaises(TypeError, get_constant_value, a[1])
self.assertRaises(TypeError, get_constant_value, a[2])
self.assertRaises(tensor.NotScalarConstantError, get_scalar_constant_value, a[0])
self.assertRaises(tensor.NotScalarConstantError, get_scalar_constant_value, a[1])
self.assertRaises(tensor.NotScalarConstantError, get_scalar_constant_value, a[2])
# Test the case SubTensor(Shape(v)) when the dimensions
# is broadcastable.
v = tensor.row()
assert get_constant_value(v.shape[0]) == 1
assert get_scalar_constant_value(v.shape[0]) == 1
def test_subtensor_of_constant(self):
c = constant(rand(5))
for i in range(c.value.shape[0]):
assert get_constant_value(c[i]) == c.value[i]
assert get_scalar_constant_value(c[i]) == c.value[i]
c = constant(rand(5, 5))
for i in range(c.value.shape[0]):
for j in range(c.value.shape[1]):
assert get_constant_value(c[i, j]) == c.value[i, j]
assert get_scalar_constant_value(c[i, j]) == c.value[i, j]
class T_as_tensor_variable(unittest.TestCase):
......
......@@ -856,7 +856,7 @@ def test_gt_grad():
"""A user test that failed.
Something about it made Elemwise.grad return something that was
too complicated for get_constant_value to recognize as being 0, so
too complicated for get_scalar_constant_value to recognize as being 0, so
gradient.grad reported that it was not a valid gradient of an
integer.
......
......@@ -936,7 +936,7 @@ class T_fibby(unittest.TestCase):
if node.op == fibby:
x = node.inputs[0]
try:
if numpy.all(0 == get_constant_value(x)):
if numpy.all(0 == get_scalar_constant_value(x)):
return [x]
except TypeError:
pass
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论