Commit 483cb9d3, authored by lamblin

Merge pull request #1161 from goodfeli/rebase

Ready to merge: get rid of dangerous "TypeError = not constant" mechanism
...@@ -163,7 +163,7 @@ def dot(l, r): ...@@ -163,7 +163,7 @@ def dot(l, r):
return rval return rval
def get_constant_value(v): def get_scalar_constant_value(v):
"""return the constant scalar(0-D) value underlying variable `v` """return the constant scalar(0-D) value underlying variable `v`
If v is the output of dimshuffles, fills, allocs, rebroadcasts, cast If v is the output of dimshuffles, fills, allocs, rebroadcasts, cast
...@@ -171,12 +171,13 @@ def get_constant_value(v): ...@@ -171,12 +171,13 @@ def get_constant_value(v):
If theano.sparse is also there, we will look over CSM op. If theano.sparse is also there, we will look over CSM op.
If `v` is not some view of constant data, then raise a TypeError. If `v` is not some view of constant data, then raise a
tensor.basic.NotScalarConstantError.
""" """
if hasattr(theano, 'sparse') and isinstance(v.type, if hasattr(theano, 'sparse') and isinstance(v.type,
theano.sparse.SparseType): theano.sparse.SparseType):
if v.owner is not None and isinstance(v.owner.op, if v.owner is not None and isinstance(v.owner.op,
theano.sparse.CSM): theano.sparse.CSM):
data = v.owner.inputs[0] data = v.owner.inputs[0]
return tensor.get_constant_value(data) return tensor.get_scalar_constant_value(data)
return tensor.get_constant_value(v) return tensor.get_scalar_constant_value(v)
...@@ -975,7 +975,7 @@ def _populate_grad_dict(var_to_app_to_idx, ...@@ -975,7 +975,7 @@ def _populate_grad_dict(var_to_app_to_idx,
msg += "%s." msg += "%s."
msg % (str(node.op), str(term), str(type(term)), msg % (str(node.op), str(term), str(type(term)),
i, str(theano.get_constant_value(term))) i, str(theano.get_scalar_constant_value(term)))
raise ValueError(msg) raise ValueError(msg)
...@@ -1616,9 +1616,9 @@ def _is_zero(x): ...@@ -1616,9 +1616,9 @@ def _is_zero(x):
no_constant_value = True no_constant_value = True
try: try:
constant_value = theano.get_constant_value(x) constant_value = theano.get_scalar_constant_value(x)
no_constant_value = False no_constant_value = False
except TypeError: except theano.tensor.basic.NotScalarConstantError:
pass pass
if no_constant_value: if no_constant_value:
......
...@@ -2691,8 +2691,8 @@ class GpuAlloc(GpuOp): ...@@ -2691,8 +2691,8 @@ class GpuAlloc(GpuOp):
raise TypeError('Shape arguments must be integers', s) raise TypeError('Shape arguments must be integers', s)
# if s is constant 1, then we're broadcastable in that dim # if s is constant 1, then we're broadcastable in that dim
try: try:
const_shp = tensor.get_constant_value(s) const_shp = tensor.get_scalar_constant_value(s)
except TypeError: except tensor.NotScalarConstantError:
const_shp = None const_shp = None
bcast.append(numpy.all(1 == const_shp)) bcast.append(numpy.all(1 == const_shp))
otype = CudaNdarrayType(dtype='float32', broadcastable=bcast) otype = CudaNdarrayType(dtype='float32', broadcastable=bcast)
......
...@@ -214,8 +214,8 @@ def is_positive(v): ...@@ -214,8 +214,8 @@ def is_positive(v):
logger.debug('is_positive: %s' % str(v)) logger.debug('is_positive: %s' % str(v))
if v.owner and v.owner.op == tensor.pow: if v.owner and v.owner.op == tensor.pow:
try: try:
exponent = tensor.get_constant_value(v.owner.inputs[1]) exponent = tensor.get_scalar_constant_value(v.owner.inputs[1])
except TypeError: except tensor.basic.NotScalarConstantError:
return False return False
if 0 == exponent % 2: if 0 == exponent % 2:
return True return True
......
...@@ -135,8 +135,8 @@ def scan(fn, ...@@ -135,8 +135,8 @@ def scan(fn,
n_fixed_steps = int(n_steps) n_fixed_steps = int(n_steps)
else: else:
try: try:
n_fixed_steps = opt.get_constant_value(n_steps) n_fixed_steps = opt.get_scalar_constant_value(n_steps)
except (TypeError, AttributeError): except tensor.basic.NotScalarConstantError:
n_fixed_steps = None n_fixed_steps = None
# Check n_steps is an int # Check n_steps is an int
......
...@@ -335,7 +335,7 @@ def scan(fn, ...@@ -335,7 +335,7 @@ def scan(fn,
T_value = int(n_steps) T_value = int(n_steps)
else: else:
try: try:
T_value = opt.get_constant_value(n_steps) T_value = opt.get_scalar_constant_value(n_steps)
except (TypeError, AttributeError): except (TypeError, AttributeError):
T_value = None T_value = None
......
...@@ -24,7 +24,7 @@ from theano.compile.pfunc import rebuild_collect_shared ...@@ -24,7 +24,7 @@ from theano.compile.pfunc import rebuild_collect_shared
from theano import gof from theano import gof
from theano import tensor, scalar from theano import tensor, scalar
from theano.gof.python25 import all from theano.gof.python25 import all
from theano.tensor.basic import get_constant_value from theano.tensor.basic import get_scalar_constant_value
# Logging function for sending warning or info # Logging function for sending warning or info
......
...@@ -363,8 +363,8 @@ def scan(fn, ...@@ -363,8 +363,8 @@ def scan(fn,
n_fixed_steps = int(n_steps) n_fixed_steps = int(n_steps)
else: else:
try: try:
n_fixed_steps = opt.get_constant_value(n_steps) n_fixed_steps = opt.get_scalar_constant_value(n_steps)
except (TypeError, AttributeError): except tensor.basic.NotScalarConstantError:
n_fixed_steps = None n_fixed_steps = None
# Check n_steps is an int # Check n_steps is an int
......
...@@ -18,7 +18,7 @@ import numpy ...@@ -18,7 +18,7 @@ import numpy
import theano import theano
from theano import tensor from theano import tensor
from theano.tensor import opt, get_constant_value from theano.tensor import opt, get_scalar_constant_value
from theano import gof from theano import gof
from theano.gof.python25 import maxsize, any from theano.gof.python25 import maxsize, any
from theano.gof.opt import Optimizer from theano.gof.opt import Optimizer
...@@ -1164,14 +1164,14 @@ class ScanMerge(gof.Optimizer): ...@@ -1164,14 +1164,14 @@ class ScanMerge(gof.Optimizer):
nsteps = node.inputs[0] nsteps = node.inputs[0]
try: try:
nsteps = int(get_constant_value(nsteps)) nsteps = int(get_scalar_constant_value(nsteps))
except TypeError: except tensor.NotScalarConstantError:
pass pass
rep_nsteps = rep.inputs[0] rep_nsteps = rep.inputs[0]
try: try:
rep_nsteps = int(get_constant_value(rep_nsteps)) rep_nsteps = int(get_scalar_constant_value(rep_nsteps))
except TypeError: except tensor.NotScalarConstantError:
pass pass
# Check to see if it is an input of a different node # Check to see if it is an input of a different node
......
...@@ -25,7 +25,7 @@ from theano.compile.pfunc import rebuild_collect_shared ...@@ -25,7 +25,7 @@ from theano.compile.pfunc import rebuild_collect_shared
from theano import gof from theano import gof
from theano import tensor, scalar from theano import tensor, scalar
from theano.gof.python25 import all, OrderedDict from theano.gof.python25 import all, OrderedDict
from theano.tensor.basic import get_constant_value from theano.tensor.basic import get_scalar_constant_value
################ Utility Functions and Classes ####################### ################ Utility Functions and Classes #######################
...@@ -308,7 +308,7 @@ def isNaN_or_Inf_or_None(x): ...@@ -308,7 +308,7 @@ def isNaN_or_Inf_or_None(x):
isStr = False isStr = False
if not isNaN and not isInf: if not isNaN and not isInf:
try: try:
val = get_constant_value(x) val = get_scalar_constant_value(x)
isInf = numpy.isinf(val) isInf = numpy.isinf(val)
isNaN = numpy.isnan(val) isNaN = numpy.isnan(val)
except Exception: except Exception:
......
...@@ -463,70 +463,88 @@ def _allclose(a, b, rtol=None, atol=None): ...@@ -463,70 +463,88 @@ def _allclose(a, b, rtol=None, atol=None):
return numpy.allclose(a, b, atol=atol_, rtol=rtol_) return numpy.allclose(a, b, atol=atol_, rtol=rtol_)
class NotConstantError(TypeError): class NotScalarConstantError(Exception):
""" """
Raised by get_constant_value if called on something that is Raised by get_scalar_constant_value if called on something that is
not constant. not a scalar constant.
For now it is a TypeError, to maintain the old interface """
that get_constant_value should raise a TypeError in this
situation. However, this is unsafe because get_constant_value class EmptyConstantError(NotScalarConstantError):
could inadvertently raise a TypeError if it has a bug. """
So we should eventually make NotConstantError derive Raised by get_scalar_constant_value if called on something that is an
from Exception directly, and modify all code that uses empty (zero-size) constant, e.g. numpy.array([]).
get_constant_value to catch this more specific exception.
""" """
pass
def get_constant_value(v): def get_scalar_constant_value(v):
"""return the constant scalar(0-D) value underlying variable `v` """return the constant scalar(0-D) value underlying variable `v`
If v is the output of dimshuffles, fills, allocs, rebroadcasts, cast If v is the output of dimshuffles, fills, allocs, rebroadcasts, cast
this function digs through them. this function digs through them.
If `v` is not some view of constant data, then raise a NotConstantError. If `v` is not some view of constant scalar data, then raise a NotScalarConstantError.
:note: There may be another function similar to this one in the :note: There may be another function similar to this one in the
code, but I'm not sure where it is. code, but I'm not sure where it is.
""" """
if isinstance(v, Constant): if v is None:
if getattr(v.tag, 'unique_value', None) is not None: # None is not a scalar (and many uses of this function seem to depend
data = v.tag.unique_value # on passing it None)
else: raise NotScalarConstantError()
data = v.data
if isinstance(v, (int, float)):
return numpy.asarray(v)
def numpy_scalar(n):
""" Return a scalar stored in a numpy ndarray, or raise
NotScalarConstantError if the numpy ndarray is not a scalar
"""
# handle case where data is numpy.array([]) # handle case where data is numpy.array([])
if hasattr(data, 'shape') and len(data.shape) == 0 or \ if data.ndim > 0 and (len(data.shape) == 0 or
__builtins__['max'](data.shape) == 0: __builtins__['max'](data.shape) == 0):
assert numpy.all(numpy.array([]) == data) assert numpy.all(numpy.array([]) == data)
return data raise EmptyConstantError()
try: try:
numpy.complex(data) # works for all numeric scalars numpy.complex(data) # works for all numeric scalars
return data return data
except Exception: except Exception:
raise NotConstantError( raise NotScalarConstantError(
'v.data is non-numeric, non-scalar, or has more than one' 'v.data is non-numeric, non-scalar, or has more than one'
' unique value', v) ' unique value', n)
if isinstance(v, numpy.ndarray):
return numpy_scalar(v)
if isinstance(v, Constant):
if getattr(v.tag, 'unique_value', None) is not None:
data = v.tag.unique_value
else:
data = v.data
return numpy_scalar(data)
if v.owner: if v.owner:
if isinstance(v.owner.op, Alloc): if isinstance(v.owner.op, Alloc):
return get_constant_value(v.owner.inputs[0]) return get_scalar_constant_value(v.owner.inputs[0])
if isinstance(v.owner.op, DimShuffle): if isinstance(v.owner.op, DimShuffle):
return get_constant_value(v.owner.inputs[0]) return get_scalar_constant_value(v.owner.inputs[0])
if isinstance(v.owner.op, Rebroadcast): if isinstance(v.owner.op, Rebroadcast):
return get_constant_value(v.owner.inputs[0]) return get_scalar_constant_value(v.owner.inputs[0])
if isinstance(v.owner.op, Elemwise) and \ if isinstance(v.owner.op, Elemwise) and \
isinstance(v.owner.op.scalar_op, scal.Second): isinstance(v.owner.op.scalar_op, scal.Second):
shape, val = v.owner.inputs shape, val = v.owner.inputs
return get_constant_value(val) return get_scalar_constant_value(val)
if isinstance(v.owner.op, scal.Second): if isinstance(v.owner.op, scal.Second):
x, y = v.owner.inputs x, y = v.owner.inputs
return get_constant_value(y) return get_scalar_constant_value(y)
# Don't act as the constant_folding optimization here as this # Don't act as the constant_folding optimization here as this
# fct is used too early in the optimization phase. This would # fct is used too early in the optimization phase. This would
# mess with the stabilization optimization. # mess with the stabilization optimization.
if (isinstance(v.owner.op, Elemwise) and isinstance( if (isinstance(v.owner.op, Elemwise) and isinstance(
v.owner.op.scalar_op, scal.Cast)) or \ v.owner.op.scalar_op, scal.Cast)) or \
isinstance(v.owner.op, scal.Cast): isinstance(v.owner.op, scal.Cast):
const = get_constant_value(v.owner.inputs[0]) const = get_scalar_constant_value(v.owner.inputs[0])
ret = [[None]] ret = [[None]]
v.owner.op.perform(v.owner, [const], ret) v.owner.op.perform(v.owner, [const], ret)
return ret[0][0] return ret[0][0]
...@@ -563,7 +581,7 @@ def get_constant_value(v): ...@@ -563,7 +581,7 @@ def get_constant_value(v):
# axis. # axis.
ret = v.owner.inputs[0].owner.inputs[ ret = v.owner.inputs[0].owner.inputs[
v.owner.op.idx_list[0] + 1] v.owner.op.idx_list[0] + 1]
ret = get_constant_value(ret) ret = get_scalar_constant_value(ret)
# join can cast implicitly its input in some case. # join can cast implicitly its input in some case.
return theano._asarray(ret, dtype=v.type.dtype) return theano._asarray(ret, dtype=v.type.dtype)
if (v.owner.inputs[0].owner and if (v.owner.inputs[0].owner and
...@@ -576,7 +594,7 @@ def get_constant_value(v): ...@@ -576,7 +594,7 @@ def get_constant_value(v):
len(v.owner.op.idx_list) == 1): len(v.owner.op.idx_list) == 1):
ret = v.owner.inputs[0].owner.inputs[v.owner.op.idx_list[0]] ret = v.owner.inputs[0].owner.inputs[v.owner.op.idx_list[0]]
ret = get_constant_value(ret) ret = get_scalar_constant_value(ret)
# MakeVector can cast implicitly its input in some case. # MakeVector can cast implicitly its input in some case.
return theano._asarray(ret, dtype=v.type.dtype) return theano._asarray(ret, dtype=v.type.dtype)
...@@ -589,7 +607,7 @@ def get_constant_value(v): ...@@ -589,7 +607,7 @@ def get_constant_value(v):
v.owner.op.idx_list[0]]: v.owner.op.idx_list[0]]:
return numpy.asarray(1) return numpy.asarray(1)
raise NotConstantError(v) raise NotScalarConstantError(v)
class TensorType(Type): class TensorType(Type):
...@@ -1832,8 +1850,8 @@ class _tensor_py_operators: ...@@ -1832,8 +1850,8 @@ class _tensor_py_operators:
# TO TRUMP NUMPY OPERATORS # TO TRUMP NUMPY OPERATORS
__array_priority__ = 1000 __array_priority__ = 1000
def get_constant_value(self): def get_scalar_constant_value(self):
return get_constant_value(self) return get_scalar_constant_value(self)
def zeros_like(model): def zeros_like(model):
return zeros_like(model) return zeros_like(model)
...@@ -2361,10 +2379,10 @@ class SpecifyShape(Op): ...@@ -2361,10 +2379,10 @@ class SpecifyShape(Op):
new_shape = [] new_shape = []
for dim in xrange(node.inputs[0].ndim): for dim in xrange(node.inputs[0].ndim):
try: try:
s = get_constant_value(node.inputs[1][dim]) s = get_scalar_constant_value(node.inputs[1][dim])
s = as_tensor_variable(s) s = as_tensor_variable(s)
new_shape.append(s) new_shape.append(s)
except TypeError: except NotScalarConstantError:
new_shape.append(node.inputs[1][dim]) new_shape.append(node.inputs[1][dim])
assert len(new_shape) == len(xshape) assert len(new_shape) == len(xshape)
...@@ -2656,14 +2674,22 @@ def max(x, axis=None, keepdims=False): ...@@ -2656,14 +2674,22 @@ def max(x, axis=None, keepdims=False):
:note: we return an error as numpy when we reduce a dim with a shape of 0 :note: we return an error as numpy when we reduce a dim with a shape of 0
""" """
if isinstance(axis, (list, tuple)) and len(axis) > 1: # We have a choice of implementing this call with the
# CAReduce op or the MaxAndArgmax op.
# MaxAndArgmax supports grad and Rop, so we prefer to use that.
# CAReduce is faster, but optimizations will replace MaxAndArgmax[0]
# with CAReduce at compile time, so at this stage the important
# thing is supporting all user interface features, not speed.
# Some cases can be implemented only with CAReduce.
# We thus prefer to use MaxAndArgmax, if possible. It does not
# support all axis arguments, so we may need to fall back to CAReduce.
try:
out = max_and_argmax(x, axis)[0]
except Exception:
out = CAReduce(scal.maximum, axis)(x) out = CAReduce(scal.maximum, axis)(x)
else:
try:
const = get_constant_value(axis)
out = CAReduce(scal.maximum, list(const))(x)
except Exception:
out = max_and_argmax(x, axis)[0]
if keepdims: if keepdims:
out = makeKeepDims(x, out, axis) out = makeKeepDims(x, out, axis)
...@@ -3271,8 +3297,8 @@ class Alloc(gof.Op): ...@@ -3271,8 +3297,8 @@ class Alloc(gof.Op):
(i, s_as_str)) (i, s_as_str))
# if s is constant 1, then we're broadcastable in that dim # if s is constant 1, then we're broadcastable in that dim
try: try:
const_shp = get_constant_value(s) const_shp = get_scalar_constant_value(s)
except TypeError: except NotScalarConstantError:
const_shp = None const_shp = None
bcast.append(numpy.all(1 == const_shp)) bcast.append(numpy.all(1 == const_shp))
otype = TensorType(dtype=v.dtype, broadcastable=bcast) otype = TensorType(dtype=v.dtype, broadcastable=bcast)
...@@ -3811,16 +3837,16 @@ def get_idx_list(inputs, idx_list): ...@@ -3811,16 +3837,16 @@ def get_idx_list(inputs, idx_list):
def extract_constant(x): def extract_constant(x):
''' '''
This function is basically a call to tensor.get_constant_value. The This function is basically a call to tensor.get_scalar_constant_value. The
main difference is the behaviour in case of failure. While main difference is the behaviour in case of failure. While
get_constant_value raises a TypeError, this function returns x, get_scalar_constant_value raises a NotScalarConstantError, this function returns x,
as a tensor if possible. If x is a ScalarVariable from a as a tensor if possible. If x is a ScalarVariable from a
scalar_from_tensor, we remove the conversion. If x is just a scalar_from_tensor, we remove the conversion. If x is just a
ScalarVariable, we convert it to a tensor with tensor_from_scalar. ScalarVariable, we convert it to a tensor with tensor_from_scalar.
''' '''
try: try:
x = get_constant_value(x) x = get_scalar_constant_value(x)
except Exception: except NotScalarConstantError:
pass pass
if (isinstance(x, scal.ScalarVariable) or if (isinstance(x, scal.ScalarVariable) or
isinstance(x, scal.sharedvar.ScalarSharedVariable)): isinstance(x, scal.sharedvar.ScalarSharedVariable)):
...@@ -5419,11 +5445,11 @@ class Join(Op): ...@@ -5419,11 +5445,11 @@ class Join(Op):
# Axis can also be a constant # Axis can also be a constant
if not isinstance(axis, int): if not isinstance(axis, int):
try: try:
# Note : `get_constant_value` returns a ndarray not a # Note : `get_scalar_constant_value` returns a ndarray not a
# int # int
axis = int(get_constant_value(axis)) axis = int(get_scalar_constant_value(axis))
except TypeError: except NotScalarConstantError:
pass pass
if isinstance(axis, int): if isinstance(axis, int):
# Basically, broadcastable -> length 1, but the # Basically, broadcastable -> length 1, but the
...@@ -5790,9 +5816,9 @@ class Reshape(Op): ...@@ -5790,9 +5816,9 @@ class Reshape(Op):
# Try to see if we can infer that y has a constant value of 1. # Try to see if we can infer that y has a constant value of 1.
# If so, that dimension should be broadcastable. # If so, that dimension should be broadcastable.
try: try:
bcasts[index] = (hasattr(y, 'get_constant_value') and bcasts[index] = (hasattr(y, 'get_scalar_constant_value') and
y.get_constant_value() == 1) y.get_scalar_constant_value() == 1)
except TypeError: except NotScalarConstantError:
pass pass
return gof.Apply(self, [x, shp], [tensor(x.type.dtype, bcasts)]) return gof.Apply(self, [x, shp], [tensor(x.type.dtype, bcasts)])
...@@ -5865,10 +5891,10 @@ class Reshape(Op): ...@@ -5865,10 +5891,10 @@ class Reshape(Op):
for i in xrange(self.ndim): for i in xrange(self.ndim):
default_os_i = theano.tensor.opt.Shape_i(i)(node.outputs[0]) default_os_i = theano.tensor.opt.Shape_i(i)(node.outputs[0])
try: try:
os_i = get_constant_value(node.inputs[1][i]).item() os_i = get_scalar_constant_value(node.inputs[1][i]).item()
if os_i == -1: if os_i == -1:
os_i = default_os_i os_i = default_os_i
except TypeError: except NotScalarConstantError:
os_i = default_os_i os_i = default_os_i
oshape.append(os_i) oshape.append(os_i)
return [tuple(oshape)] return [tuple(oshape)]
...@@ -6148,9 +6174,9 @@ class ARange(Op): ...@@ -6148,9 +6174,9 @@ class ARange(Op):
def is_constant_value(var, value): def is_constant_value(var, value):
try: try:
v = get_constant_value(var) v = get_scalar_constant_value(var)
return numpy.all(v == value) return numpy.all(v == value)
except Exception: except NotScalarConstantError:
pass pass
return False return False
......
...@@ -1614,7 +1614,7 @@ def local_gemm_to_ger(node): ...@@ -1614,7 +1614,7 @@ def local_gemm_to_ger(node):
xv = x.dimshuffle(0) xv = x.dimshuffle(0)
yv = y.dimshuffle(1) yv = y.dimshuffle(1)
try: try:
bval = T.get_constant_value(b) bval = T.get_scalar_constant_value(b)
except TypeError: except TypeError:
# b isn't a constant, GEMM is doing useful pre-scaling # b isn't a constant, GEMM is doing useful pre-scaling
return return
......
...@@ -262,15 +262,15 @@ class RepeatOp(theano.Op): ...@@ -262,15 +262,15 @@ class RepeatOp(theano.Op):
broadcastable=[False] broadcastable=[False]
else: else:
try: try:
const_reps = basic.get_constant_value(repeats) const_reps = basic.get_scalar_constant_value(repeats)
except basic.NotConstantError: except basic.NotScalarConstantError:
const_reps = None const_reps = None
if const_reps == 1: if const_reps == 1:
broadcastable = x.broadcastable broadcastable = x.broadcastable
else: else:
broadcastable = list(x.broadcastable) broadcastable = list(x.broadcastable)
broadcastable[self.axis] = False broadcastable[self.axis] = False
out_type = theano.tensor.TensorType(x.dtype, broadcastable) out_type = theano.tensor.TensorType(x.dtype, broadcastable)
return theano.Apply(self, [x, repeats], [out_type()]) return theano.Apply(self, [x, repeats], [out_type()])
......
...@@ -15,7 +15,7 @@ import logging ...@@ -15,7 +15,7 @@ import logging
import numpy import numpy
import theano import theano
from theano.tensor import (as_tensor_variable, blas, get_constant_value, from theano.tensor import (as_tensor_variable, blas, get_scalar_constant_value,
patternbroadcast) patternbroadcast)
from theano import OpenMPOp, config from theano import OpenMPOp, config
from theano.gof import Apply from theano.gof import Apply
...@@ -90,7 +90,7 @@ def conv2d(input, filters, image_shape=None, filter_shape=None, ...@@ -90,7 +90,7 @@ def conv2d(input, filters, image_shape=None, filter_shape=None,
image_shape = list(image_shape) image_shape = list(image_shape)
for i in xrange(len(image_shape)): for i in xrange(len(image_shape)):
if image_shape[i] is not None: if image_shape[i] is not None:
image_shape[i] = get_constant_value( image_shape[i] = get_scalar_constant_value(
as_tensor_variable(image_shape[i])) as_tensor_variable(image_shape[i]))
assert str(image_shape[i].dtype).startswith('int') assert str(image_shape[i].dtype).startswith('int')
image_shape[i] = int(image_shape[i]) image_shape[i] = int(image_shape[i])
...@@ -98,7 +98,7 @@ def conv2d(input, filters, image_shape=None, filter_shape=None, ...@@ -98,7 +98,7 @@ def conv2d(input, filters, image_shape=None, filter_shape=None,
filter_shape = list(filter_shape) filter_shape = list(filter_shape)
for i in xrange(len(filter_shape)): for i in xrange(len(filter_shape)):
if filter_shape[i] is not None: if filter_shape[i] is not None:
filter_shape[i] = get_constant_value( filter_shape[i] = get_scalar_constant_value(
as_tensor_variable(filter_shape[i])) as_tensor_variable(filter_shape[i]))
assert str(filter_shape[i].dtype).startswith('int') assert str(filter_shape[i].dtype).startswith('int')
filter_shape[i] = int(filter_shape[i]) filter_shape[i] = int(filter_shape[i])
......
...@@ -1409,8 +1409,8 @@ def _check_rows_is_arange_len_labels(rows, labels): ...@@ -1409,8 +1409,8 @@ def _check_rows_is_arange_len_labels(rows, labels):
def _is_const(z, val, approx=False): def _is_const(z, val, approx=False):
try: try:
maybe = opt.get_constant_value(z) maybe = opt.get_scalar_constant_value(z)
except TypeError: except tensor.NotScalarConstantError:
return False return False
if approx: if approx:
return numpy.allclose(maybe, val) return numpy.allclose(maybe, val)
......
...@@ -14,7 +14,7 @@ from theano.compile import optdb ...@@ -14,7 +14,7 @@ from theano.compile import optdb
from theano.configparser import AddConfigVar, BoolParam from theano.configparser import AddConfigVar, BoolParam
from theano.printing import pprint, debugprint from theano.printing import pprint, debugprint
from theano.tensor import basic as tensor from theano.tensor import basic as tensor
from theano.tensor import elemwise, opt from theano.tensor import elemwise, opt, NotScalarConstantError
############ ############
...@@ -136,9 +136,9 @@ def _is_1(expr): ...@@ -136,9 +136,9 @@ def _is_1(expr):
"""rtype bool. True iff expr is a constant close to 1 """rtype bool. True iff expr is a constant close to 1
""" """
try: try:
v = opt.get_constant_value(expr) v = opt.get_scalar_constant_value(expr)
return numpy.allclose(v, 1) return numpy.allclose(v, 1)
except TypeError: except tensor.NotScalarConstantError:
return False return False
log1msigm_to_softplus = gof.PatternSub( log1msigm_to_softplus = gof.PatternSub(
...@@ -275,9 +275,9 @@ def is_neg(var): ...@@ -275,9 +275,9 @@ def is_neg(var):
if apply.op == tensor.mul and len(apply.inputs) >= 2: if apply.op == tensor.mul and len(apply.inputs) >= 2:
for idx, mul_input in enumerate(apply.inputs): for idx, mul_input in enumerate(apply.inputs):
try: try:
constant = opt.get_constant_value(mul_input) constant = opt.get_scalar_constant_value(mul_input)
is_minus_1 = numpy.allclose(constant, -1) is_minus_1 = numpy.allclose(constant, -1)
except TypeError: except NotScalarConstantError:
is_minus_1 = False is_minus_1 = False
if is_minus_1: if is_minus_1:
# Found a multiplication by -1. # Found a multiplication by -1.
...@@ -647,7 +647,7 @@ def local_1msigmoid(node): ...@@ -647,7 +647,7 @@ def local_1msigmoid(node):
return # graph is using both sigm and 1-sigm return # graph is using both sigm and 1-sigm
if sub_r.owner and sub_r.owner.op == sigmoid: if sub_r.owner and sub_r.owner.op == sigmoid:
try: try:
val_l = opt.get_constant_value(sub_l) val_l = opt.get_scalar_constant_value(sub_l)
except Exception, e: except Exception, e:
return return
if numpy.allclose(numpy.sum(val_l), 1): if numpy.allclose(numpy.sum(val_l), 1):
......
...@@ -30,10 +30,10 @@ class TestConv2D(utt.InferShapeTester): ...@@ -30,10 +30,10 @@ class TestConv2D(utt.InferShapeTester):
verify_grad=True, should_raise=False): verify_grad=True, should_raise=False):
if N_image_shape is None: if N_image_shape is None:
N_image_shape = [T.get_constant_value(T. N_image_shape = [T.get_scalar_constant_value(T.
as_tensor_variable(x)) for x in image_shape] as_tensor_variable(x)) for x in image_shape]
if N_filter_shape is None: if N_filter_shape is None:
N_filter_shape = [T.get_constant_value(T. N_filter_shape = [T.get_scalar_constant_value(T.
as_tensor_variable(x)) for x in filter_shape] as_tensor_variable(x)) for x in filter_shape]
if input is None: if input is None:
......
...@@ -33,7 +33,7 @@ from theano.gof.opt import (Optimizer, pre_constant_merge, ...@@ -33,7 +33,7 @@ from theano.gof.opt import (Optimizer, pre_constant_merge,
pre_greedy_local_optimizer) pre_greedy_local_optimizer)
from theano.gof.opt import merge_optimizer from theano.gof.opt import merge_optimizer
from theano.gof import toolbox, DestroyHandler from theano.gof import toolbox, DestroyHandler
from basic import get_constant_value, ShapeError from basic import get_scalar_constant_value, ShapeError, NotScalarConstantError
theano.configparser.AddConfigVar('on_shape_error', theano.configparser.AddConfigVar('on_shape_error',
...@@ -92,10 +92,10 @@ def scalarconsts_rest(inputs): ...@@ -92,10 +92,10 @@ def scalarconsts_rest(inputs):
nonconsts = [] nonconsts = []
for i in inputs: for i in inputs:
try: try:
v = get_constant_value(i) v = get_scalar_constant_value(i)
consts.append(v) consts.append(v)
origconsts.append(i) origconsts.append(i)
except Exception: except NotScalarConstantError:
nonconsts.append(i) nonconsts.append(i)
return consts, origconsts, nonconsts return consts, origconsts, nonconsts
...@@ -125,7 +125,13 @@ def broadcast_like(value, template, fgraph, dtype=None): ...@@ -125,7 +125,13 @@ def broadcast_like(value, template, fgraph, dtype=None):
if rval.broadcastable[i] if rval.broadcastable[i]
and not template.broadcastable[i]]) and not template.broadcastable[i]])
assert rval.type.dtype == dtype assert rval.type.dtype == dtype
assert rval.type.broadcastable == template.broadcastable
if rval.type.broadcastable != template.broadcastable:
raise AssertionError("rval.type.broadcastable is " +
str(rval.type.broadcastable) +
" but template.broadcastable is" +
str(template.broadcastable))
return rval return rval
...@@ -322,15 +328,15 @@ def local_0_dot_x(node): ...@@ -322,15 +328,15 @@ def local_0_dot_x(node):
y = node.inputs[1] y = node.inputs[1]
replace = False replace = False
try: try:
if get_constant_value(x) == 0: if get_scalar_constant_value(x) == 0:
replace = True replace = True
except TypeError: except NotScalarConstantError:
pass pass
try: try:
if get_constant_value(y) == 0: if get_scalar_constant_value(y) == 0:
replace = True replace = True
except TypeError: except NotScalarConstantError:
pass pass
if replace: if replace:
...@@ -1177,9 +1183,9 @@ def local_subtensor_make_vector(node): ...@@ -1177,9 +1183,9 @@ def local_subtensor_make_vector(node):
elif isinstance(idx, Variable): elif isinstance(idx, Variable):
# if it is a constant we can do something with it # if it is a constant we can do something with it
try: try:
v = get_constant_value(idx) v = get_scalar_constant_value(idx)
return [x.owner.inputs[v]] return [x.owner.inputs[v]]
except Exception: except NotScalarConstantError:
pass pass
else: else:
# it is a slice of ints and/or Variables # it is a slice of ints and/or Variables
...@@ -1315,13 +1321,13 @@ def local_remove_useless_assert(node): ...@@ -1315,13 +1321,13 @@ def local_remove_useless_assert(node):
cond = [] cond = []
for c in node.inputs[1:]: for c in node.inputs[1:]:
try: try:
const = get_constant_value(c) const = get_scalar_constant_value(c)
if 0 != const.ndim or const == 0: if 0 != const.ndim or const == 0:
#Should we raise an error here? How to be sure it #Should we raise an error here? How to be sure it
#is not caught? #is not caught?
cond.append(c) cond.append(c)
except TypeError: except NotScalarConstantError:
cond.append(c) cond.append(c)
if len(cond) == 0: if len(cond) == 0:
...@@ -1477,7 +1483,7 @@ def local_upcast_elemwise_constant_inputs(node): ...@@ -1477,7 +1483,7 @@ def local_upcast_elemwise_constant_inputs(node):
else: else:
try: try:
# works only for scalars # works only for scalars
cval_i = get_constant_value(i) cval_i = get_scalar_constant_value(i)
if all(i.broadcastable): if all(i.broadcastable):
new_inputs.append(T.shape_padleft( new_inputs.append(T.shape_padleft(
T.cast(cval_i, output_dtype), T.cast(cval_i, output_dtype),
...@@ -1490,7 +1496,7 @@ def local_upcast_elemwise_constant_inputs(node): ...@@ -1490,7 +1496,7 @@ def local_upcast_elemwise_constant_inputs(node):
*[shape_i(d)(i) for d in xrange(i.ndim)])) *[shape_i(d)(i) for d in xrange(i.ndim)]))
#print >> sys.stderr, "AAA", #print >> sys.stderr, "AAA",
#*[Shape_i(d)(i) for d in xrange(i.ndim)] #*[Shape_i(d)(i) for d in xrange(i.ndim)]
except TypeError: except NotScalarConstantError:
#for the case of a non-scalar #for the case of a non-scalar
if isinstance(i, T.TensorConstant): if isinstance(i, T.TensorConstant):
new_inputs.append(T.cast(i, output_dtype)) new_inputs.append(T.cast(i, output_dtype))
...@@ -1550,8 +1556,8 @@ def local_useless_subtensor(node): ...@@ -1550,8 +1556,8 @@ def local_useless_subtensor(node):
length_pos = shape_of[node.inputs[0]][pos] length_pos = shape_of[node.inputs[0]][pos]
try: try:
length_pos_data = get_constant_value(length_pos) length_pos_data = get_scalar_constant_value(length_pos)
except TypeError: except NotScalarConstantError:
pass pass
if isinstance(idx.stop, int): if isinstance(idx.stop, int):
...@@ -2032,9 +2038,9 @@ def local_incsubtensor_of_allocs(node): ...@@ -2032,9 +2038,9 @@ def local_incsubtensor_of_allocs(node):
y = node.inputs[1] y = node.inputs[1]
replace = False replace = False
try: try:
if get_constant_value(y) == 0: if get_scalar_constant_value(y) == 0:
replace = True replace = True
except TypeError: except NotScalarConstantError:
pass pass
if replace: if replace:
...@@ -2059,13 +2065,13 @@ def local_setsubtensor_of_allocs(node): ...@@ -2059,13 +2065,13 @@ def local_setsubtensor_of_allocs(node):
replace_y = None replace_y = None
try: try:
replace_x = get_constant_value(x) replace_x = get_scalar_constant_value(x)
except TypeError: except NotScalarConstantError:
pass pass
try: try:
replace_y = get_constant_value(y) replace_y = get_scalar_constant_value(y)
except TypeError: except NotScalarConstantError:
pass pass
if (replace_x == replace_y and if (replace_x == replace_y and
...@@ -2253,24 +2259,24 @@ def local_mul_switch_sink(node): ...@@ -2253,24 +2259,24 @@ def local_mul_switch_sink(node):
if i.owner and i.owner.op == T.switch: if i.owner and i.owner.op == T.switch:
switch = i.owner switch = i.owner
try: try:
if get_constant_value(switch.inputs[1]) == 0.: if get_scalar_constant_value(switch.inputs[1]) == 0.:
listmul = node.inputs[:idx] + node.inputs[idx + 1:] listmul = node.inputs[:idx] + node.inputs[idx + 1:]
fct = [T.switch(switch.inputs[0], 0, fct = [T.switch(switch.inputs[0], 0,
T.mul(*(listmul + [switch.inputs[2]])))] T.mul(*(listmul + [switch.inputs[2]])))]
fct[0].values_eq_approx = fct[ fct[0].values_eq_approx = fct[
0].type.values_eq_approx_remove_nan 0].type.values_eq_approx_remove_nan
return fct return fct
except TypeError: except NotScalarConstantError:
pass pass
try: try:
if get_constant_value(switch.inputs[2]) == 0.: if get_scalar_constant_value(switch.inputs[2]) == 0.:
listmul = node.inputs[:idx] + node.inputs[idx + 1:] listmul = node.inputs[:idx] + node.inputs[idx + 1:]
fct = [T.switch(switch.inputs[0], fct = [T.switch(switch.inputs[0],
T.mul(*(listmul + [switch.inputs[1]])), 0)] T.mul(*(listmul + [switch.inputs[1]])), 0)]
fct[0].values_eq_approx = fct[ fct[0].values_eq_approx = fct[
0].type.values_eq_approx_remove_nan 0].type.values_eq_approx_remove_nan
return fct return fct
except TypeError: except NotScalarConstantError:
pass pass
return False return False
...@@ -2295,22 +2301,22 @@ def local_div_switch_sink(node): ...@@ -2295,22 +2301,22 @@ def local_div_switch_sink(node):
if node.inputs[0].owner and node.inputs[0].owner.op == T.switch: if node.inputs[0].owner and node.inputs[0].owner.op == T.switch:
switch = node.inputs[0].owner switch = node.inputs[0].owner
try: try:
if get_constant_value(switch.inputs[1]) == 0.: if get_scalar_constant_value(switch.inputs[1]) == 0.:
fct = [T.switch(switch.inputs[0], 0, fct = [T.switch(switch.inputs[0], 0,
op(switch.inputs[2], node.inputs[1]))] op(switch.inputs[2], node.inputs[1]))]
fct[0].values_eq_approx = fct[ fct[0].values_eq_approx = fct[
0].type.values_eq_approx_remove_nan 0].type.values_eq_approx_remove_nan
return fct return fct
except TypeError: except NotScalarConstantError:
pass pass
try: try:
if get_constant_value(switch.inputs[2]) == 0.: if get_scalar_constant_value(switch.inputs[2]) == 0.:
fct = [T.switch(switch.inputs[0], fct = [T.switch(switch.inputs[0],
op(switch.inputs[1], node.inputs[1]), 0)] op(switch.inputs[1], node.inputs[1]), 0)]
fct[0].values_eq_approx = fct[ fct[0].values_eq_approx = fct[
0].type.values_eq_approx_remove_nan 0].type.values_eq_approx_remove_nan
return fct return fct
except TypeError: except NotScalarConstantError:
pass pass
return False return False
...@@ -2375,7 +2381,7 @@ if 0: ...@@ -2375,7 +2381,7 @@ if 0:
def tmp(thing): def tmp(thing):
try: try:
return T.get_constant_value(thing) return T.get_scalar_constant_value(thing)
except (TypeError, ValueError), e: except (TypeError, ValueError), e:
print e, thing.owner.inputs[0] print e, thing.owner.inputs[0]
return None return None
...@@ -2702,8 +2708,8 @@ class Canonizer(gof.LocalOptimizer): ...@@ -2702,8 +2708,8 @@ class Canonizer(gof.LocalOptimizer):
""" """
if isinstance(v, Variable): if isinstance(v, Variable):
try: try:
return get_constant_value(v) return get_scalar_constant_value(v)
except TypeError: except NotScalarConstantError:
return None return None
else: else:
return v return v
...@@ -3204,15 +3210,15 @@ def local_sum_alloc(node): ...@@ -3204,15 +3210,15 @@ def local_sum_alloc(node):
if (node.op.axis is None or if (node.op.axis is None or
node.op.axis == tuple(range(input.ndim))): node.op.axis == tuple(range(input.ndim))):
try: try:
val = get_constant_value(input) val = get_scalar_constant_value(input)
assert val.size == 1 assert val.size == 1
val = val.reshape(1)[0] * T.mul(*shapes) val = val.reshape(1)[0] * T.mul(*shapes)
return [T.cast(val, dtype=node.outputs[0].dtype)] return [T.cast(val, dtype=node.outputs[0].dtype)]
except TypeError: except NotScalarConstantError:
pass pass
else: else:
try: try:
val = get_constant_value(input) val = get_scalar_constant_value(input)
assert val.size == 1 assert val.size == 1
val = val.reshape(1)[0] val = val.reshape(1)[0]
to_prod = [shapes[i] for i in xrange(len(shapes)) to_prod = [shapes[i] for i in xrange(len(shapes))
...@@ -3222,7 +3228,7 @@ def local_sum_alloc(node): ...@@ -3222,7 +3228,7 @@ def local_sum_alloc(node):
return [T.alloc(T.cast(val, dtype=node.outputs[0].dtype), return [T.alloc(T.cast(val, dtype=node.outputs[0].dtype),
*[shapes[i] for i in xrange(len(shapes)) *[shapes[i] for i in xrange(len(shapes))
if i not in node.op.axis])] if i not in node.op.axis])]
except TypeError: except NotScalarConstantError:
pass pass
...@@ -3282,8 +3288,8 @@ def local_mul_zero(node): ...@@ -3282,8 +3288,8 @@ def local_mul_zero(node):
for i in node.inputs: for i in node.inputs:
try: try:
value = get_constant_value(i) value = get_scalar_constant_value(i)
except TypeError: except NotScalarConstantError:
continue continue
#print 'MUL by value', value, node.inputs #print 'MUL by value', value, node.inputs
if N.all(value == 0): if N.all(value == 0):
...@@ -3520,8 +3526,8 @@ def local_add_specialize(node): ...@@ -3520,8 +3526,8 @@ def local_add_specialize(node):
new_inputs = [] new_inputs = []
for input in node.inputs: for input in node.inputs:
try: try:
y = get_constant_value(input) y = get_scalar_constant_value(input)
except TypeError: except NotScalarConstantError:
y = input y = input
if numpy.all(y == 0.0): if numpy.all(y == 0.0):
continue continue
...@@ -3614,7 +3620,7 @@ def local_abs_merge(node): ...@@ -3614,7 +3620,7 @@ def local_abs_merge(node):
if i.owner and i.owner.op == T.abs_: if i.owner and i.owner.op == T.abs_:
inputs.append(i.owner.inputs[0]) inputs.append(i.owner.inputs[0])
else: else:
const = get_constant_value(i) const = get_scalar_constant_value(i)
if not (const >= 0).all(): if not (const >= 0).all():
return False return False
inputs.append(i) inputs.append(i)
...@@ -3880,9 +3886,9 @@ def _is_1(expr): ...@@ -3880,9 +3886,9 @@ def _is_1(expr):
"""rtype bool. True iff expr is a constant close to 1 """rtype bool. True iff expr is a constant close to 1
""" """
try: try:
v = get_constant_value(expr) v = get_scalar_constant_value(expr)
return numpy.allclose(v, 1) return numpy.allclose(v, 1)
except TypeError: except NotScalarConstantError:
return False return False
...@@ -3890,9 +3896,9 @@ def _is_minus1(expr): ...@@ -3890,9 +3896,9 @@ def _is_minus1(expr):
"""rtype bool. True iff expr is a constant close to -1 """rtype bool. True iff expr is a constant close to -1
""" """
try: try:
v = get_constant_value(expr) v = get_scalar_constant_value(expr)
return numpy.allclose(v, -1) return numpy.allclose(v, -1)
except TypeError: except NotScalarConstantError:
return False return False
#1+erf(x)=>erfc(-x) #1+erf(x)=>erfc(-x)
...@@ -4132,8 +4138,8 @@ def local_grad_log_erfc_neg(node): ...@@ -4132,8 +4138,8 @@ def local_grad_log_erfc_neg(node):
mul_neg = T.mul(*mul_inputs) mul_neg = T.mul(*mul_inputs)
try: try:
cst2 = get_constant_value(mul_neg.owner.inputs[0]) cst2 = get_scalar_constant_value(mul_neg.owner.inputs[0])
except TypeError: except NotScalarConstantError:
return False return False
if len(mul_neg.owner.inputs) == 2: if len(mul_neg.owner.inputs) == 2:
...@@ -4159,8 +4165,8 @@ def local_grad_log_erfc_neg(node): ...@@ -4159,8 +4165,8 @@ def local_grad_log_erfc_neg(node):
x = erfc_x x = erfc_x
try: try:
cst = get_constant_value(erfc_x.owner.inputs[0]) cst = get_scalar_constant_value(erfc_x.owner.inputs[0])
except TypeError: except NotScalarConstantError:
return False return False
if cst2 != -cst * 2: if cst2 != -cst * 2:
return False return False
......
...@@ -41,7 +41,7 @@ from theano.gof.python25 import any, all ...@@ -41,7 +41,7 @@ from theano.gof.python25 import any, all
from theano.gof.opt import Optimizer from theano.gof.opt import Optimizer
from theano.gof import InconsistencyError, toolbox from theano.gof import InconsistencyError, toolbox
from basic import get_constant_value from basic import get_scalar_constant_value, NotScalarConstantError
from theano.tensor.opt import register_uncanonicalize from theano.tensor.opt import register_uncanonicalize
from theano import scalar as scal from theano import scalar as scal
...@@ -64,8 +64,8 @@ class MaxAndArgmaxOptimizer(Optimizer): ...@@ -64,8 +64,8 @@ class MaxAndArgmaxOptimizer(Optimizer):
if node.op == T._max_and_argmax: if node.op == T._max_and_argmax:
if len(node.outputs[1].clients)==0: if len(node.outputs[1].clients)==0:
try: try:
axis=get_constant_value(node.inputs[1]) axis=get_scalar_constant_value(node.inputs[1])
except (ValueError, TypeError), e: except NotScalarConstantError:
return False return False
new = CAReduce(scal.maximum,axis)(node.inputs[0]) new = CAReduce(scal.maximum,axis)(node.inputs[0])
......
...@@ -34,7 +34,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as, ...@@ -34,7 +34,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as,
Reshape, row, scalar, scalars, second, smallest, stack, sub, Tensor, Reshape, row, scalar, scalars, second, smallest, stack, sub, Tensor,
tensor_copy, tensordot, tensordot_grad, TensorType, unbroadcast, tensor_copy, tensordot, tensordot_grad, TensorType, unbroadcast,
var, Join, shape, MaxAndArgmax, lscalar, zvector, exp, var, Join, shape, MaxAndArgmax, lscalar, zvector, exp,
get_constant_value, ivector, reshape, scalar_from_tensor, scal, get_scalar_constant_value, ivector, reshape, scalar_from_tensor, scal,
iscalars, arange, dscalars, fvector, imatrix, numeric_grad, iscalars, arange, dscalars, fvector, imatrix, numeric_grad,
opt, ComplexError, TensorDot, lvector, true_div, max, min, Split, roll, opt, ComplexError, TensorDot, lvector, true_div, max, min, Split, roll,
tile, patternbroadcast, Eye, Shape, Default, Dot, PermuteRowElements, tile, patternbroadcast, Eye, Shape, Default, Dot, PermuteRowElements,
...@@ -2140,7 +2140,7 @@ class T_max_and_argmax(unittest.TestCase): ...@@ -2140,7 +2140,7 @@ class T_max_and_argmax(unittest.TestCase):
cost = argmax(x, axis=0).sum() cost = argmax(x, axis=0).sum()
value_error_raised = False value_error_raised = False
gx = grad(cost, x) gx = grad(cost, x)
val = tensor.get_constant_value(gx) val = tensor.get_scalar_constant_value(gx)
assert val == 0.0 assert val == 0.0
def test_grad(self): def test_grad(self):
...@@ -6167,40 +6167,40 @@ def test_dimshuffle_duplicate(): ...@@ -6167,40 +6167,40 @@ def test_dimshuffle_duplicate():
assert success assert success
class T_get_constant_value(unittest.TestCase): class T_get_scalar_constant_value(unittest.TestCase):
def test_get_constant_value(self): def test_get_scalar_constant_value(self):
a = tensor.stack(1, 2, 3) a = tensor.stack(1, 2, 3)
assert get_constant_value(a[0]) == 1 assert get_scalar_constant_value(a[0]) == 1
assert get_constant_value(a[1]) == 2 assert get_scalar_constant_value(a[1]) == 2
assert get_constant_value(a[2]) == 3 assert get_scalar_constant_value(a[2]) == 3
b = tensor.iscalar() b = tensor.iscalar()
a = tensor.stack(b, 2, 3) a = tensor.stack(b, 2, 3)
self.assertRaises(TypeError, get_constant_value, a[0]) self.assertRaises(tensor.basic.NotScalarConstantError, get_scalar_constant_value, a[0])
assert get_constant_value(a[1]) == 2 assert get_scalar_constant_value(a[1]) == 2
assert get_constant_value(a[2]) == 3 assert get_scalar_constant_value(a[2]) == 3
# For now get_constant_value goes through only MakeVector and Join of # For now get_scalar_constant_value goes through only MakeVector and Join of
# scalars. # scalars.
v = tensor.ivector() v = tensor.ivector()
a = tensor.stack(v, 2, 3) a = tensor.stack(v, 2, 3)
self.assertRaises(TypeError, get_constant_value, a[0]) self.assertRaises(tensor.NotScalarConstantError, get_scalar_constant_value, a[0])
self.assertRaises(TypeError, get_constant_value, a[1]) self.assertRaises(tensor.NotScalarConstantError, get_scalar_constant_value, a[1])
self.assertRaises(TypeError, get_constant_value, a[2]) self.assertRaises(tensor.NotScalarConstantError, get_scalar_constant_value, a[2])
# Test the case SubTensor(Shape(v)) when the dimension # Test the case SubTensor(Shape(v)) when the dimension
# is broadcastable. # is broadcastable.
v = tensor.row() v = tensor.row()
assert get_constant_value(v.shape[0]) == 1 assert get_scalar_constant_value(v.shape[0]) == 1
def test_subtensor_of_constant(self): def test_subtensor_of_constant(self):
c = constant(rand(5)) c = constant(rand(5))
for i in range(c.value.shape[0]): for i in range(c.value.shape[0]):
assert get_constant_value(c[i]) == c.value[i] assert get_scalar_constant_value(c[i]) == c.value[i]
c = constant(rand(5, 5)) c = constant(rand(5, 5))
for i in range(c.value.shape[0]): for i in range(c.value.shape[0]):
for j in range(c.value.shape[1]): for j in range(c.value.shape[1]):
assert get_constant_value(c[i, j]) == c.value[i, j] assert get_scalar_constant_value(c[i, j]) == c.value[i, j]
class T_as_tensor_variable(unittest.TestCase): class T_as_tensor_variable(unittest.TestCase):
......
...@@ -856,7 +856,7 @@ def test_gt_grad(): ...@@ -856,7 +856,7 @@ def test_gt_grad():
"""A user test that failed. """A user test that failed.
Something about it made Elemwise.grad return something that was Something about it made Elemwise.grad return something that was
too complicated for get_constant_value to recognize as being 0, so too complicated for get_scalar_constant_value to recognize as being 0, so
gradient.grad reported that it was not a valid gradient of an gradient.grad reported that it was not a valid gradient of an
integer. integer.
......
...@@ -936,7 +936,7 @@ class T_fibby(unittest.TestCase): ...@@ -936,7 +936,7 @@ class T_fibby(unittest.TestCase):
if node.op == fibby: if node.op == fibby:
x = node.inputs[0] x = node.inputs[0]
try: try:
if numpy.all(0 == get_constant_value(x)): if numpy.all(0 == get_scalar_constant_value(x)):
return [x] return [x]
except TypeError: except TypeError:
pass pass
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论