提交 434dd96e authored 作者: Ian Goodfellow

get rid of dangerous "TypeError = not constant" mechanism

上级 5237b952
......@@ -171,7 +171,8 @@ def get_constant_value(v):
If theano.sparse is also there, we will look over CSM op.
If `v` is not some view of constant data, then raise a TypeError.
If `v` is not some view of constant data, then raise a
tensor.basic.NotConstantError.
"""
if hasattr(theano, 'sparse') and isinstance(v.type,
theano.sparse.SparseType):
......
......@@ -1618,7 +1618,7 @@ def _is_zero(x):
try:
constant_value = theano.get_constant_value(x)
no_constant_value = False
except TypeError:
except theano.tensor.basic.NotConstantError:
pass
if no_constant_value:
......
......@@ -215,7 +215,7 @@ def is_positive(v):
if v.owner and v.owner.op == tensor.pow:
try:
exponent = tensor.get_constant_value(v.owner.inputs[1])
except TypeError:
except tensor.basic.NotConstantError:
return False
if 0 == exponent % 2:
return True
......
......@@ -136,7 +136,7 @@ def scan(fn,
else:
try:
n_fixed_steps = opt.get_constant_value(n_steps)
except (TypeError, AttributeError):
except tensor.basic.NotConstantError:
n_fixed_steps = None
# Check n_steps is an int
......
......@@ -364,7 +364,7 @@ def scan(fn,
else:
try:
n_fixed_steps = opt.get_constant_value(n_steps)
except (TypeError, AttributeError):
except tensor.basic.NotConstantError:
n_fixed_steps = None
# Check n_steps is an int
......
......@@ -463,17 +463,10 @@ def _allclose(a, b, rtol=None, atol=None):
return numpy.allclose(a, b, atol=atol_, rtol=rtol_)
class NotConstantError(TypeError):
class NotConstantError(Exception):
"""
Raised by get_constant_value if called on something that is
not constant.
For now it is a TypeError, to maintain the old interface
that get_constant_value should raise a TypeError in this
situation. However, this is unsafe because get_constant_value
could inadvertently raise a TypeError if it has a bug.
So we should eventually make NotConstantError derive
from Exception directly, and modify all code that uses
get_constant_value to catch this more specific exception.
"""
pass
......@@ -2364,7 +2357,7 @@ class SpecifyShape(Op):
s = get_constant_value(node.inputs[1][dim])
s = as_tensor_variable(s)
new_shape.append(s)
except TypeError:
except NotConstantError:
new_shape.append(node.inputs[1][dim])
assert len(new_shape) == len(xshape)
......@@ -2662,7 +2655,7 @@ def max(x, axis=None, keepdims=False):
try:
const = get_constant_value(axis)
out = CAReduce(scal.maximum, list(const))(x)
except Exception:
except NotConstantError:
out = max_and_argmax(x, axis)[0]
if keepdims:
......@@ -3272,7 +3265,7 @@ class Alloc(gof.Op):
# if s is constant 1, then we're broadcastable in that dim
try:
const_shp = get_constant_value(s)
except TypeError:
except NotConstantError:
const_shp = None
bcast.append(numpy.all(1 == const_shp))
otype = TensorType(dtype=v.dtype, broadcastable=bcast)
......@@ -3820,7 +3813,7 @@ def extract_constant(x):
'''
try:
x = get_constant_value(x)
except Exception:
except NotConstantError:
pass
if (isinstance(x, scal.ScalarVariable) or
isinstance(x, scal.sharedvar.ScalarSharedVariable)):
......@@ -5423,7 +5416,7 @@ class Join(Op):
# int
axis = int(get_constant_value(axis))
except TypeError:
except NotConstantError:
pass
if isinstance(axis, int):
# Basically, broadcastable -> length 1, but the
......@@ -5792,7 +5785,7 @@ class Reshape(Op):
try:
bcasts[index] = (hasattr(y, 'get_constant_value') and
y.get_constant_value() == 1)
except TypeError:
except NotConstantError:
pass
return gof.Apply(self, [x, shp], [tensor(x.type.dtype, bcasts)])
......@@ -5868,7 +5861,7 @@ class Reshape(Op):
os_i = get_constant_value(node.inputs[1][i]).item()
if os_i == -1:
os_i = default_os_i
except TypeError:
except NotConstantError:
os_i = default_os_i
oshape.append(os_i)
return [tuple(oshape)]
......@@ -6150,7 +6143,7 @@ class ARange(Op):
try:
v = get_constant_value(var)
return numpy.all(v == value)
except Exception:
except NotConstantError:
pass
return False
......
......@@ -138,7 +138,7 @@ def _is_1(expr):
try:
v = opt.get_constant_value(expr)
return numpy.allclose(v, 1)
except TypeError:
except tensor.NotConstantError:
return False
log1msigm_to_softplus = gof.PatternSub(
......
......@@ -33,7 +33,7 @@ from theano.gof.opt import (Optimizer, pre_constant_merge,
pre_greedy_local_optimizer)
from theano.gof.opt import merge_optimizer
from theano.gof import toolbox, DestroyHandler
from basic import get_constant_value, ShapeError
from basic import get_constant_value, ShapeError, NotConstantError
theano.configparser.AddConfigVar('on_shape_error',
......@@ -95,7 +95,7 @@ def scalarconsts_rest(inputs):
v = get_constant_value(i)
consts.append(v)
origconsts.append(i)
except Exception:
except NotConstantError:
nonconsts.append(i)
return consts, origconsts, nonconsts
......@@ -324,13 +324,13 @@ def local_0_dot_x(node):
try:
if get_constant_value(x) == 0:
replace = True
except TypeError:
except NotConstantError:
pass
try:
if get_constant_value(y) == 0:
replace = True
except TypeError:
except NotConstantError:
pass
if replace:
......@@ -1179,7 +1179,7 @@ def local_subtensor_make_vector(node):
try:
v = get_constant_value(idx)
return [x.owner.inputs[v]]
except Exception:
except NotConstantError:
pass
else:
# it is a slice of ints and/or Variables
......@@ -1321,7 +1321,7 @@ def local_remove_useless_assert(node):
#Should we raise an error here? How to be sure it
#is not catched?
cond.append(c)
except TypeError:
except NotConstantError:
cond.append(c)
if len(cond) == 0:
......@@ -1551,7 +1551,7 @@ def local_useless_subtensor(node):
length_pos = shape_of[node.inputs[0]][pos]
try:
length_pos_data = get_constant_value(length_pos)
except TypeError:
except NotConstantError:
pass
if isinstance(idx.stop, int):
......@@ -2034,7 +2034,7 @@ def local_incsubtensor_of_allocs(node):
try:
if get_constant_value(y) == 0:
replace = True
except TypeError:
except NotConstantError:
pass
if replace:
......@@ -2060,12 +2060,12 @@ def local_setsubtensor_of_allocs(node):
try:
replace_x = get_constant_value(x)
except TypeError:
except NotConstantError:
pass
try:
replace_y = get_constant_value(y)
except TypeError:
except NotConstantError:
pass
if (replace_x == replace_y and
......@@ -2260,7 +2260,7 @@ def local_mul_switch_sink(node):
fct[0].values_eq_approx = fct[
0].type.values_eq_approx_remove_nan
return fct
except TypeError:
except NotConstantError:
pass
try:
if get_constant_value(switch.inputs[2]) == 0.:
......@@ -2270,7 +2270,7 @@ def local_mul_switch_sink(node):
fct[0].values_eq_approx = fct[
0].type.values_eq_approx_remove_nan
return fct
except TypeError:
except NotConstantError:
pass
return False
......@@ -2301,7 +2301,7 @@ def local_div_switch_sink(node):
fct[0].values_eq_approx = fct[
0].type.values_eq_approx_remove_nan
return fct
except TypeError:
except NotConstantError:
pass
try:
if get_constant_value(switch.inputs[2]) == 0.:
......@@ -2310,7 +2310,7 @@ def local_div_switch_sink(node):
fct[0].values_eq_approx = fct[
0].type.values_eq_approx_remove_nan
return fct
except TypeError:
except NotConstantError:
pass
return False
......@@ -2703,7 +2703,7 @@ class Canonizer(gof.LocalOptimizer):
if isinstance(v, Variable):
try:
return get_constant_value(v)
except TypeError:
except NotConstantError:
return None
else:
return v
......@@ -3208,7 +3208,7 @@ def local_sum_alloc(node):
assert val.size == 1
val = val.reshape(1)[0] * T.mul(*shapes)
return [T.cast(val, dtype=node.outputs[0].dtype)]
except TypeError:
except NotConstantError:
pass
else:
try:
......@@ -3222,7 +3222,7 @@ def local_sum_alloc(node):
return [T.alloc(T.cast(val, dtype=node.outputs[0].dtype),
*[shapes[i] for i in xrange(len(shapes))
if i not in node.op.axis])]
except TypeError:
except NotConstantError:
pass
......@@ -3283,7 +3283,7 @@ def local_mul_zero(node):
for i in node.inputs:
try:
value = get_constant_value(i)
except TypeError:
except NotConstantError:
continue
#print 'MUL by value', value, node.inputs
if N.all(value == 0):
......@@ -3521,7 +3521,7 @@ def local_add_specialize(node):
for input in node.inputs:
try:
y = get_constant_value(input)
except TypeError:
except NotConstantError:
y = input
if numpy.all(y == 0.0):
continue
......@@ -3882,7 +3882,7 @@ def _is_1(expr):
try:
v = get_constant_value(expr)
return numpy.allclose(v, 1)
except TypeError:
except NotConstantError:
return False
......@@ -3892,7 +3892,7 @@ def _is_minus1(expr):
try:
v = get_constant_value(expr)
return numpy.allclose(v, -1)
except TypeError:
except NotConstantError:
return False
#1+erf(x)=>erfc(-x)
......@@ -4133,7 +4133,7 @@ def local_grad_log_erfc_neg(node):
try:
cst2 = get_constant_value(mul_neg.owner.inputs[0])
except TypeError:
except NotConstantError:
return False
if len(mul_neg.owner.inputs) == 2:
......@@ -4160,7 +4160,7 @@ def local_grad_log_erfc_neg(node):
x = erfc_x
try:
cst = get_constant_value(erfc_x.owner.inputs[0])
except TypeError:
except NotConstantError:
return False
if cst2 != -cst * 2:
return False
......
......@@ -6176,7 +6176,7 @@ class T_get_constant_value(unittest.TestCase):
b = tensor.iscalar()
a = tensor.stack(b, 2, 3)
self.assertRaises(TypeError, get_constant_value, a[0])
self.assertRaises(tensor.basic.NotConstantError, get_constant_value, a[0])
assert get_constant_value(a[1]) == 2
assert get_constant_value(a[2]) == 3
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论