提交 12d07fe1 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

added some defensive programming

reduces # of cases where Rop succesfully runs on invalid input complete safety would require use of Assert op which disrupts optimizations, so not done for now
上级 1cbb3cb8
......@@ -20,6 +20,7 @@ from theano.gof import Apply, Constant, Op, Type, Value, Variable
import elemwise
from theano import scalar as scal
from theano.gof.python25 import partial, any, all
from theano.gof.op import get_test_value
from theano import compile, printing
from theano.printing import pprint, min_informative_str
......@@ -5216,8 +5217,46 @@ class Dot(Op):
# simply c \dot b + a \dot d
if None in eval_points:
return [None]
t1 = self.make_node(eval_points[0], inputs[1]).outputs[0]
t2 = self.make_node(inputs[0], eval_points[1]).outputs[0]
assert len(inputs) == 2
assert len(eval_points) == 2
debugger_available = config.compute_test_value != 'off'
if debugger_available:
try:
iv0 = get_test_value(inputs[0])
except AttributeError:
raise AttributeError('first input passed to Dot.R_op has no test value')
try:
iv1 = get_test_value(inputs[1])
except AttributeError:
raise AttributeError('second input passed to Dot.R_op has no test value')
try:
ev0 = get_test_value(eval_points[0])
except AttributeError:
raise AttributeError('first eval point passed to Dot.R_op has no test value')
try:
ev1 = get_test_value(eval_points[1])
except AttributeError:
raise AttributeError('second eval point passed to Dot.R_op has no test value')
input_values = [ iv0, iv1]
eval_point_values = [ ev0, ev1 ]
for i in xrange(2):
if input_values[i].shape != eval_point_values[i].shape:
raise ValueError('input '+str(i)+' and eval_point '+str(i)+' to Dot.R_op '
'should have the '
'same shape, but their shapes are %s and %s, respectively' % ( \
str(input_values[i].shape), str(eval_point_values[i]) ) )
t1 = self(eval_points[0], inputs[1])
t2 = self(inputs[0], eval_points[1])
return [t1+t2]
......
......@@ -64,6 +64,23 @@ def Rop(f, wrt, eval_points):
assert len(wrt) == len(eval_points)
for pack in enumerate(zip(wrt, eval_points)):
i = pack[0]
wrt_elem, eval_point = pack[1]
wrt_elem = as_tensor_variable(wrt_elem)
eval_point = as_tensor_variable(eval_point)
wrt_dim = len(wrt_elem.type.broadcastable)
eval_dim = len(eval_point.type.broadcastable)
if wrt_dim != eval_dim:
from theano.printing import min_informative_str
print min_informative_str(wrt_elem)
print min_informative_str(eval_point)
raise ValueError('Element '+str(i)+' of wrt/eval_point have mismatched '
'dimensionality: '+str(wrt_dim)+' versus '+str(eval_dim))
seen_nodes = {}
def _traverse(node):
......
......@@ -261,3 +261,14 @@ class test_RopLop(unittest.TestCase):
self.mat_in_shape[0], self.mat_in_shape[1], self.in_shape[0])
self.check_rop_lop(out3d.flatten(),
self.mat_in_shape[0] * self.mat_in_shape[1] * self.in_shape[0])
def test_invalid_input(self):
success = False
try:
tensor.Rop(0., [ tensor.matrix() ], [ tensor.vector() ] )
success = True
except ValueError:
pass
assert not success
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论