提交 434dd96e 作者: Ian Goodfellow

get rid of dangerous "TypeError = not constant" mechanism

上级 5237b952
...@@ -171,7 +171,8 @@ def get_constant_value(v): ...@@ -171,7 +171,8 @@ def get_constant_value(v):
If theano.sparse is also there, we will look over CSM op. If theano.sparse is also there, we will look over CSM op.
If `v` is not some view of constant data, then raise a TypeError. If `v` is not some view of constant data, then raise a
tensor.basic.NotConstantError.
""" """
if hasattr(theano, 'sparse') and isinstance(v.type, if hasattr(theano, 'sparse') and isinstance(v.type,
theano.sparse.SparseType): theano.sparse.SparseType):
......
...@@ -1618,7 +1618,7 @@ def _is_zero(x): ...@@ -1618,7 +1618,7 @@ def _is_zero(x):
try: try:
constant_value = theano.get_constant_value(x) constant_value = theano.get_constant_value(x)
no_constant_value = False no_constant_value = False
except TypeError: except theano.tensor.basic.NotConstantError:
pass pass
if no_constant_value: if no_constant_value:
......
...@@ -215,7 +215,7 @@ def is_positive(v): ...@@ -215,7 +215,7 @@ def is_positive(v):
if v.owner and v.owner.op == tensor.pow: if v.owner and v.owner.op == tensor.pow:
try: try:
exponent = tensor.get_constant_value(v.owner.inputs[1]) exponent = tensor.get_constant_value(v.owner.inputs[1])
except TypeError: except tensor.basic.NotConstantError:
return False return False
if 0 == exponent % 2: if 0 == exponent % 2:
return True return True
......
...@@ -136,7 +136,7 @@ def scan(fn, ...@@ -136,7 +136,7 @@ def scan(fn,
else: else:
try: try:
n_fixed_steps = opt.get_constant_value(n_steps) n_fixed_steps = opt.get_constant_value(n_steps)
except (TypeError, AttributeError): except tensor.basic.NotConstantError:
n_fixed_steps = None n_fixed_steps = None
# Check n_steps is an int # Check n_steps is an int
......
...@@ -364,7 +364,7 @@ def scan(fn, ...@@ -364,7 +364,7 @@ def scan(fn,
else: else:
try: try:
n_fixed_steps = opt.get_constant_value(n_steps) n_fixed_steps = opt.get_constant_value(n_steps)
except (TypeError, AttributeError): except tensor.basic.NotConstantError:
n_fixed_steps = None n_fixed_steps = None
# Check n_steps is an int # Check n_steps is an int
......
...@@ -463,17 +463,10 @@ def _allclose(a, b, rtol=None, atol=None): ...@@ -463,17 +463,10 @@ def _allclose(a, b, rtol=None, atol=None):
return numpy.allclose(a, b, atol=atol_, rtol=rtol_) return numpy.allclose(a, b, atol=atol_, rtol=rtol_)
class NotConstantError(TypeError): class NotConstantError(Exception):
""" """
Raised by get_constant_value if called on something that is Raised by get_constant_value if called on something that is
not constant. not constant.
For now it is a TypeError, to maintain the old interface
that get_constant_value should raise a TypeError in this
situation. However, this is unsafe because get_constant_value
could inadvertently raise a TypeError if it has a bug.
So we should eventually make NotConstantError derive
from Exception directly, and modify all code that uses
get_constant_value to catch this more specific exception.
""" """
pass pass
...@@ -2364,7 +2357,7 @@ class SpecifyShape(Op): ...@@ -2364,7 +2357,7 @@ class SpecifyShape(Op):
s = get_constant_value(node.inputs[1][dim]) s = get_constant_value(node.inputs[1][dim])
s = as_tensor_variable(s) s = as_tensor_variable(s)
new_shape.append(s) new_shape.append(s)
except TypeError: except NotConstantError:
new_shape.append(node.inputs[1][dim]) new_shape.append(node.inputs[1][dim])
assert len(new_shape) == len(xshape) assert len(new_shape) == len(xshape)
...@@ -2662,7 +2655,7 @@ def max(x, axis=None, keepdims=False): ...@@ -2662,7 +2655,7 @@ def max(x, axis=None, keepdims=False):
try: try:
const = get_constant_value(axis) const = get_constant_value(axis)
out = CAReduce(scal.maximum, list(const))(x) out = CAReduce(scal.maximum, list(const))(x)
except Exception: except NotConstantError:
out = max_and_argmax(x, axis)[0] out = max_and_argmax(x, axis)[0]
if keepdims: if keepdims:
...@@ -3272,7 +3265,7 @@ class Alloc(gof.Op): ...@@ -3272,7 +3265,7 @@ class Alloc(gof.Op):
# if s is constant 1, then we're broadcastable in that dim # if s is constant 1, then we're broadcastable in that dim
try: try:
const_shp = get_constant_value(s) const_shp = get_constant_value(s)
except TypeError: except NotConstantError:
const_shp = None const_shp = None
bcast.append(numpy.all(1 == const_shp)) bcast.append(numpy.all(1 == const_shp))
otype = TensorType(dtype=v.dtype, broadcastable=bcast) otype = TensorType(dtype=v.dtype, broadcastable=bcast)
...@@ -3820,7 +3813,7 @@ def extract_constant(x): ...@@ -3820,7 +3813,7 @@ def extract_constant(x):
''' '''
try: try:
x = get_constant_value(x) x = get_constant_value(x)
except Exception: except NotConstantError:
pass pass
if (isinstance(x, scal.ScalarVariable) or if (isinstance(x, scal.ScalarVariable) or
isinstance(x, scal.sharedvar.ScalarSharedVariable)): isinstance(x, scal.sharedvar.ScalarSharedVariable)):
...@@ -5423,7 +5416,7 @@ class Join(Op): ...@@ -5423,7 +5416,7 @@ class Join(Op):
# int # int
axis = int(get_constant_value(axis)) axis = int(get_constant_value(axis))
except TypeError: except NotConstantError:
pass pass
if isinstance(axis, int): if isinstance(axis, int):
# Basically, broadcastable -> length 1, but the # Basically, broadcastable -> length 1, but the
...@@ -5792,7 +5785,7 @@ class Reshape(Op): ...@@ -5792,7 +5785,7 @@ class Reshape(Op):
try: try:
bcasts[index] = (hasattr(y, 'get_constant_value') and bcasts[index] = (hasattr(y, 'get_constant_value') and
y.get_constant_value() == 1) y.get_constant_value() == 1)
except TypeError: except NotConstantError:
pass pass
return gof.Apply(self, [x, shp], [tensor(x.type.dtype, bcasts)]) return gof.Apply(self, [x, shp], [tensor(x.type.dtype, bcasts)])
...@@ -5868,7 +5861,7 @@ class Reshape(Op): ...@@ -5868,7 +5861,7 @@ class Reshape(Op):
os_i = get_constant_value(node.inputs[1][i]).item() os_i = get_constant_value(node.inputs[1][i]).item()
if os_i == -1: if os_i == -1:
os_i = default_os_i os_i = default_os_i
except TypeError: except NotConstantError:
os_i = default_os_i os_i = default_os_i
oshape.append(os_i) oshape.append(os_i)
return [tuple(oshape)] return [tuple(oshape)]
...@@ -6150,7 +6143,7 @@ class ARange(Op): ...@@ -6150,7 +6143,7 @@ class ARange(Op):
try: try:
v = get_constant_value(var) v = get_constant_value(var)
return numpy.all(v == value) return numpy.all(v == value)
except Exception: except NotConstantError:
pass pass
return False return False
......
...@@ -138,7 +138,7 @@ def _is_1(expr): ...@@ -138,7 +138,7 @@ def _is_1(expr):
try: try:
v = opt.get_constant_value(expr) v = opt.get_constant_value(expr)
return numpy.allclose(v, 1) return numpy.allclose(v, 1)
except TypeError: except tensor.NotConstantError:
return False return False
log1msigm_to_softplus = gof.PatternSub( log1msigm_to_softplus = gof.PatternSub(
......
...@@ -33,7 +33,7 @@ from theano.gof.opt import (Optimizer, pre_constant_merge, ...@@ -33,7 +33,7 @@ from theano.gof.opt import (Optimizer, pre_constant_merge,
pre_greedy_local_optimizer) pre_greedy_local_optimizer)
from theano.gof.opt import merge_optimizer from theano.gof.opt import merge_optimizer
from theano.gof import toolbox, DestroyHandler from theano.gof import toolbox, DestroyHandler
from basic import get_constant_value, ShapeError from basic import get_constant_value, ShapeError, NotConstantError
theano.configparser.AddConfigVar('on_shape_error', theano.configparser.AddConfigVar('on_shape_error',
...@@ -95,7 +95,7 @@ def scalarconsts_rest(inputs): ...@@ -95,7 +95,7 @@ def scalarconsts_rest(inputs):
v = get_constant_value(i) v = get_constant_value(i)
consts.append(v) consts.append(v)
origconsts.append(i) origconsts.append(i)
except Exception: except NotConstantError:
nonconsts.append(i) nonconsts.append(i)
return consts, origconsts, nonconsts return consts, origconsts, nonconsts
...@@ -324,13 +324,13 @@ def local_0_dot_x(node): ...@@ -324,13 +324,13 @@ def local_0_dot_x(node):
try: try:
if get_constant_value(x) == 0: if get_constant_value(x) == 0:
replace = True replace = True
except TypeError: except NotConstantError:
pass pass
try: try:
if get_constant_value(y) == 0: if get_constant_value(y) == 0:
replace = True replace = True
except TypeError: except NotConstantError:
pass pass
if replace: if replace:
...@@ -1179,7 +1179,7 @@ def local_subtensor_make_vector(node): ...@@ -1179,7 +1179,7 @@ def local_subtensor_make_vector(node):
try: try:
v = get_constant_value(idx) v = get_constant_value(idx)
return [x.owner.inputs[v]] return [x.owner.inputs[v]]
except Exception: except NotConstantError:
pass pass
else: else:
# it is a slice of ints and/or Variables # it is a slice of ints and/or Variables
...@@ -1321,7 +1321,7 @@ def local_remove_useless_assert(node): ...@@ -1321,7 +1321,7 @@ def local_remove_useless_assert(node):
#Should we raise an error here? How to be sure it #Should we raise an error here? How to be sure it
#is not catched? #is not catched?
cond.append(c) cond.append(c)
except TypeError: except NotConstantError:
cond.append(c) cond.append(c)
if len(cond) == 0: if len(cond) == 0:
...@@ -1551,7 +1551,7 @@ def local_useless_subtensor(node): ...@@ -1551,7 +1551,7 @@ def local_useless_subtensor(node):
length_pos = shape_of[node.inputs[0]][pos] length_pos = shape_of[node.inputs[0]][pos]
try: try:
length_pos_data = get_constant_value(length_pos) length_pos_data = get_constant_value(length_pos)
except TypeError: except NotConstantError:
pass pass
if isinstance(idx.stop, int): if isinstance(idx.stop, int):
...@@ -2034,7 +2034,7 @@ def local_incsubtensor_of_allocs(node): ...@@ -2034,7 +2034,7 @@ def local_incsubtensor_of_allocs(node):
try: try:
if get_constant_value(y) == 0: if get_constant_value(y) == 0:
replace = True replace = True
except TypeError: except NotConstantError:
pass pass
if replace: if replace:
...@@ -2060,12 +2060,12 @@ def local_setsubtensor_of_allocs(node): ...@@ -2060,12 +2060,12 @@ def local_setsubtensor_of_allocs(node):
try: try:
replace_x = get_constant_value(x) replace_x = get_constant_value(x)
except TypeError: except NotConstantError:
pass pass
try: try:
replace_y = get_constant_value(y) replace_y = get_constant_value(y)
except TypeError: except NotConstantError:
pass pass
if (replace_x == replace_y and if (replace_x == replace_y and
...@@ -2260,7 +2260,7 @@ def local_mul_switch_sink(node): ...@@ -2260,7 +2260,7 @@ def local_mul_switch_sink(node):
fct[0].values_eq_approx = fct[ fct[0].values_eq_approx = fct[
0].type.values_eq_approx_remove_nan 0].type.values_eq_approx_remove_nan
return fct return fct
except TypeError: except NotConstantError:
pass pass
try: try:
if get_constant_value(switch.inputs[2]) == 0.: if get_constant_value(switch.inputs[2]) == 0.:
...@@ -2270,7 +2270,7 @@ def local_mul_switch_sink(node): ...@@ -2270,7 +2270,7 @@ def local_mul_switch_sink(node):
fct[0].values_eq_approx = fct[ fct[0].values_eq_approx = fct[
0].type.values_eq_approx_remove_nan 0].type.values_eq_approx_remove_nan
return fct return fct
except TypeError: except NotConstantError:
pass pass
return False return False
...@@ -2301,7 +2301,7 @@ def local_div_switch_sink(node): ...@@ -2301,7 +2301,7 @@ def local_div_switch_sink(node):
fct[0].values_eq_approx = fct[ fct[0].values_eq_approx = fct[
0].type.values_eq_approx_remove_nan 0].type.values_eq_approx_remove_nan
return fct return fct
except TypeError: except NotConstantError:
pass pass
try: try:
if get_constant_value(switch.inputs[2]) == 0.: if get_constant_value(switch.inputs[2]) == 0.:
...@@ -2310,7 +2310,7 @@ def local_div_switch_sink(node): ...@@ -2310,7 +2310,7 @@ def local_div_switch_sink(node):
fct[0].values_eq_approx = fct[ fct[0].values_eq_approx = fct[
0].type.values_eq_approx_remove_nan 0].type.values_eq_approx_remove_nan
return fct return fct
except TypeError: except NotConstantError:
pass pass
return False return False
...@@ -2703,7 +2703,7 @@ class Canonizer(gof.LocalOptimizer): ...@@ -2703,7 +2703,7 @@ class Canonizer(gof.LocalOptimizer):
if isinstance(v, Variable): if isinstance(v, Variable):
try: try:
return get_constant_value(v) return get_constant_value(v)
except TypeError: except NotConstantError:
return None return None
else: else:
return v return v
...@@ -3208,7 +3208,7 @@ def local_sum_alloc(node): ...@@ -3208,7 +3208,7 @@ def local_sum_alloc(node):
assert val.size == 1 assert val.size == 1
val = val.reshape(1)[0] * T.mul(*shapes) val = val.reshape(1)[0] * T.mul(*shapes)
return [T.cast(val, dtype=node.outputs[0].dtype)] return [T.cast(val, dtype=node.outputs[0].dtype)]
except TypeError: except NotConstantError:
pass pass
else: else:
try: try:
...@@ -3222,7 +3222,7 @@ def local_sum_alloc(node): ...@@ -3222,7 +3222,7 @@ def local_sum_alloc(node):
return [T.alloc(T.cast(val, dtype=node.outputs[0].dtype), return [T.alloc(T.cast(val, dtype=node.outputs[0].dtype),
*[shapes[i] for i in xrange(len(shapes)) *[shapes[i] for i in xrange(len(shapes))
if i not in node.op.axis])] if i not in node.op.axis])]
except TypeError: except NotConstantError:
pass pass
...@@ -3283,7 +3283,7 @@ def local_mul_zero(node): ...@@ -3283,7 +3283,7 @@ def local_mul_zero(node):
for i in node.inputs: for i in node.inputs:
try: try:
value = get_constant_value(i) value = get_constant_value(i)
except TypeError: except NotConstantError:
continue continue
#print 'MUL by value', value, node.inputs #print 'MUL by value', value, node.inputs
if N.all(value == 0): if N.all(value == 0):
...@@ -3521,7 +3521,7 @@ def local_add_specialize(node): ...@@ -3521,7 +3521,7 @@ def local_add_specialize(node):
for input in node.inputs: for input in node.inputs:
try: try:
y = get_constant_value(input) y = get_constant_value(input)
except TypeError: except NotConstantError:
y = input y = input
if numpy.all(y == 0.0): if numpy.all(y == 0.0):
continue continue
...@@ -3882,7 +3882,7 @@ def _is_1(expr): ...@@ -3882,7 +3882,7 @@ def _is_1(expr):
try: try:
v = get_constant_value(expr) v = get_constant_value(expr)
return numpy.allclose(v, 1) return numpy.allclose(v, 1)
except TypeError: except NotConstantError:
return False return False
...@@ -3892,7 +3892,7 @@ def _is_minus1(expr): ...@@ -3892,7 +3892,7 @@ def _is_minus1(expr):
try: try:
v = get_constant_value(expr) v = get_constant_value(expr)
return numpy.allclose(v, -1) return numpy.allclose(v, -1)
except TypeError: except NotConstantError:
return False return False
#1+erf(x)=>erfc(-x) #1+erf(x)=>erfc(-x)
...@@ -4133,7 +4133,7 @@ def local_grad_log_erfc_neg(node): ...@@ -4133,7 +4133,7 @@ def local_grad_log_erfc_neg(node):
try: try:
cst2 = get_constant_value(mul_neg.owner.inputs[0]) cst2 = get_constant_value(mul_neg.owner.inputs[0])
except TypeError: except NotConstantError:
return False return False
if len(mul_neg.owner.inputs) == 2: if len(mul_neg.owner.inputs) == 2:
...@@ -4160,7 +4160,7 @@ def local_grad_log_erfc_neg(node): ...@@ -4160,7 +4160,7 @@ def local_grad_log_erfc_neg(node):
x = erfc_x x = erfc_x
try: try:
cst = get_constant_value(erfc_x.owner.inputs[0]) cst = get_constant_value(erfc_x.owner.inputs[0])
except TypeError: except NotConstantError:
return False return False
if cst2 != -cst * 2: if cst2 != -cst * 2:
return False return False
......
...@@ -6176,7 +6176,7 @@ class T_get_constant_value(unittest.TestCase): ...@@ -6176,7 +6176,7 @@ class T_get_constant_value(unittest.TestCase):
b = tensor.iscalar() b = tensor.iscalar()
a = tensor.stack(b, 2, 3) a = tensor.stack(b, 2, 3)
self.assertRaises(TypeError, get_constant_value, a[0]) self.assertRaises(tensor.basic.NotConstantError, get_constant_value, a[0])
assert get_constant_value(a[1]) == 2 assert get_constant_value(a[1]) == 2
assert get_constant_value(a[2]) == 3 assert get_constant_value(a[2]) == 3
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论