Commit 9d186f9f authored by delallea

Merge pull request #60 from goodfeli/Rop_defense

added some defensive programming
...@@ -567,3 +567,23 @@ def get_test_value(v): ...@@ -567,3 +567,23 @@ def get_test_value(v):
v_tensor = theano.tensor.as_tensor_variable(v) v_tensor = theano.tensor.as_tensor_variable(v)
return PureOp._get_test_value(v_tensor) return PureOp._get_test_value(v_tensor)
def missing_test_message(msg):
    """Report that some test_value is missing.

    The form of the report follows config.compute_test_value:

    off: the interactive debugger is off, so we do nothing
    ignore: the interactive debugger is set to ignore missing inputs,
            so do nothing
    warn: display msg as a warning
    raise: raise an AttributeError with msg as the exception text
    """
    action = config.compute_test_value
    if action == 'raise':
        raise AttributeError(msg)
    elif action == 'warn':
        # stacklevel=2 attributes the warning to our caller, not this helper
        warnings.warn(msg, stacklevel=2)
    else:
        # Remaining modes deliberately do nothing; anything else is a bug.
        assert action in ('ignore', 'off')
...@@ -20,6 +20,7 @@ from theano.gof import Apply, Constant, Op, Type, Value, Variable ...@@ -20,6 +20,7 @@ from theano.gof import Apply, Constant, Op, Type, Value, Variable
import elemwise import elemwise
from theano import scalar as scal from theano import scalar as scal
from theano.gof.python25 import partial, any, all from theano.gof.python25 import partial, any, all
from theano.gof.op import get_test_value, missing_test_message
from theano import compile, printing from theano import compile, printing
from theano.printing import pprint, min_informative_str from theano.printing import pprint, min_informative_str
...@@ -5216,8 +5217,50 @@ class Dot(Op): ...@@ -5216,8 +5217,50 @@ class Dot(Op):
# simply c \dot b + a \dot d # simply c \dot b + a \dot d
if None in eval_points: if None in eval_points:
return [None] return [None]
t1 = self.make_node(eval_points[0], inputs[1]).outputs[0]
t2 = self.make_node(inputs[0], eval_points[1]).outputs[0] assert len(inputs) == 2
assert len(eval_points) == 2
# When the compute_test_value debugger is active, verify that every
# input and eval point carries a test value and that shapes agree.
debugger_available = config.compute_test_value != 'off'
if debugger_available:
    # Fetch test values for all four variables.  Each missing value is
    # reported individually; any absence disables the shape check below.
    labels = ['first input', 'second input',
              'first eval point', 'second eval point']
    variables = [inputs[0], inputs[1], eval_points[0], eval_points[1]]
    test_values = []
    for var, label in zip(variables, labels):
        try:
            test_values.append(get_test_value(var))
        except AttributeError:
            missing_test_message(label + ' passed to Dot.R_op has no test value')
            debugger_available = False
if debugger_available:
    iv0, iv1, ev0, ev1 = test_values
    input_values = [iv0, iv1]
    eval_point_values = [ev0, ev1]
    for i in xrange(2):
        if input_values[i].shape != eval_point_values[i].shape:
            raise ValueError(
                'input ' + str(i) + ' and eval_point ' + str(i) +
                ' to Dot.R_op should have the same shape, but their '
                'shapes are %s and %s, respectively' % (
                    str(input_values[i].shape),
                    str(eval_point_values[i].shape)))
# R_op of a dot is dot(eval_point0, b) + dot(a, eval_point1)
t1 = self(eval_points[0], inputs[1])
t2 = self(inputs[0], eval_points[1])
return [t1+t2] return [t1+t2]
......
...@@ -64,6 +64,20 @@ def Rop(f, wrt, eval_points): ...@@ -64,6 +64,20 @@ def Rop(f, wrt, eval_points):
assert len(wrt) == len(eval_points) assert len(wrt) == len(eval_points)
# Validate that each wrt element and its eval point have matching
# dimensionality; unpack directly in the for-clause instead of
# indexing into the enumerate tuple by hand.
for i, (wrt_elem, eval_point) in enumerate(zip(wrt, eval_points)):
    wrt_elem = as_tensor_variable(wrt_elem)
    eval_point = as_tensor_variable(eval_point)
    # Dimensionality is the length of the broadcastable pattern.
    wrt_dim = len(wrt_elem.type.broadcastable)
    eval_dim = len(eval_point.type.broadcastable)
    if wrt_dim != eval_dim:
        raise ValueError('Element ' + str(i) +
                         ' of wrt/eval_point have mismatched '
                         'dimensionality: ' + str(wrt_dim) +
                         ' versus ' + str(eval_dim))
seen_nodes = {} seen_nodes = {}
def _traverse(node): def _traverse(node):
......
...@@ -261,3 +261,14 @@ class test_RopLop(unittest.TestCase): ...@@ -261,3 +261,14 @@ class test_RopLop(unittest.TestCase):
self.mat_in_shape[0], self.mat_in_shape[1], self.in_shape[0]) self.mat_in_shape[0], self.mat_in_shape[1], self.in_shape[0])
self.check_rop_lop(out3d.flatten(), self.check_rop_lop(out3d.flatten(),
self.mat_in_shape[0] * self.mat_in_shape[1] * self.in_shape[0]) self.mat_in_shape[0] * self.mat_in_shape[1] * self.in_shape[0])
def test_invalid_input(self):
    """Rop must raise ValueError when a wrt element and its eval point
    have mismatched dimensionality (here: matrix wrt, vector eval point).

    Uses assertRaises instead of a manual success flag, so an
    unexpected pass produces a clear unittest failure message.
    """
    self.assertRaises(ValueError, tensor.Rop,
                      0., [tensor.matrix()], [tensor.vector()])
Markdown formatting is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register or sign in to comment