提交 12d07fe1 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

added some defensive programming

reduces # of cases where Rop succesfully runs on invalid input complete safety would require use of Assert op which disrupts optimizations, so not done for now
上级 1cbb3cb8
...@@ -20,6 +20,7 @@ from theano.gof import Apply, Constant, Op, Type, Value, Variable ...@@ -20,6 +20,7 @@ from theano.gof import Apply, Constant, Op, Type, Value, Variable
import elemwise import elemwise
from theano import scalar as scal from theano import scalar as scal
from theano.gof.python25 import partial, any, all from theano.gof.python25 import partial, any, all
from theano.gof.op import get_test_value
from theano import compile, printing from theano import compile, printing
from theano.printing import pprint, min_informative_str from theano.printing import pprint, min_informative_str
...@@ -5216,8 +5217,46 @@ class Dot(Op): ...@@ -5216,8 +5217,46 @@ class Dot(Op):
# simply c \dot b + a \dot d # simply c \dot b + a \dot d
if None in eval_points: if None in eval_points:
return [None] return [None]
t1 = self.make_node(eval_points[0], inputs[1]).outputs[0]
t2 = self.make_node(inputs[0], eval_points[1]).outputs[0] assert len(inputs) == 2
assert len(eval_points) == 2
debugger_available = config.compute_test_value != 'off'
if debugger_available:
try:
iv0 = get_test_value(inputs[0])
except AttributeError:
raise AttributeError('first input passed to Dot.R_op has no test value')
try:
iv1 = get_test_value(inputs[1])
except AttributeError:
raise AttributeError('second input passed to Dot.R_op has no test value')
try:
ev0 = get_test_value(eval_points[0])
except AttributeError:
raise AttributeError('first eval point passed to Dot.R_op has no test value')
try:
ev1 = get_test_value(eval_points[1])
except AttributeError:
raise AttributeError('second eval point passed to Dot.R_op has no test value')
input_values = [ iv0, iv1]
eval_point_values = [ ev0, ev1 ]
for i in xrange(2):
if input_values[i].shape != eval_point_values[i].shape:
raise ValueError('input '+str(i)+' and eval_point '+str(i)+' to Dot.R_op '
'should have the '
'same shape, but their shapes are %s and %s, respectively' % ( \
str(input_values[i].shape), str(eval_point_values[i]) ) )
t1 = self(eval_points[0], inputs[1])
t2 = self(inputs[0], eval_points[1])
return [t1+t2] return [t1+t2]
......
...@@ -64,6 +64,23 @@ def Rop(f, wrt, eval_points): ...@@ -64,6 +64,23 @@ def Rop(f, wrt, eval_points):
assert len(wrt) == len(eval_points) assert len(wrt) == len(eval_points)
for pack in enumerate(zip(wrt, eval_points)):
i = pack[0]
wrt_elem, eval_point = pack[1]
wrt_elem = as_tensor_variable(wrt_elem)
eval_point = as_tensor_variable(eval_point)
wrt_dim = len(wrt_elem.type.broadcastable)
eval_dim = len(eval_point.type.broadcastable)
if wrt_dim != eval_dim:
from theano.printing import min_informative_str
print min_informative_str(wrt_elem)
print min_informative_str(eval_point)
raise ValueError('Element '+str(i)+' of wrt/eval_point have mismatched '
'dimensionality: '+str(wrt_dim)+' versus '+str(eval_dim))
seen_nodes = {} seen_nodes = {}
def _traverse(node): def _traverse(node):
......
...@@ -261,3 +261,14 @@ class test_RopLop(unittest.TestCase): ...@@ -261,3 +261,14 @@ class test_RopLop(unittest.TestCase):
self.mat_in_shape[0], self.mat_in_shape[1], self.in_shape[0]) self.mat_in_shape[0], self.mat_in_shape[1], self.in_shape[0])
self.check_rop_lop(out3d.flatten(), self.check_rop_lop(out3d.flatten(),
self.mat_in_shape[0] * self.mat_in_shape[1] * self.in_shape[0]) self.mat_in_shape[0] * self.mat_in_shape[1] * self.in_shape[0])
def test_invalid_input(self):
success = False
try:
tensor.Rop(0., [ tensor.matrix() ], [ tensor.vector() ] )
success = True
except ValueError:
pass
assert not success
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论