提交 4af5f8ad authored 作者: khaotik's avatar khaotik

make conn pat check for grad() safe for OfG

plus some minor code speedup in gradient.py
上级 88626f5d
......@@ -181,12 +181,12 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
dx2, dw2, db2 = T.grad(
zz2, [xx, ww, bb],
return_disconnected='Disconnected',
disconnected_inputs='ignore',
null_gradients='return')
fn2 = function([xx, ww, bb], [dx2, dw2, db2])
dxv2, dwv2, dbv2 = fn2(xv, wv, bv)
assert numpy.allclose(wv * 2, dxv)
assert isinstance(dwv2.type, NullType)
assert isinstance(dbv2.type, DisconnectedType)
assert isinstance(dx2.type, T.TensorType)
assert dx2.ndim == 1
assert isinstance(dw2.type, NullType)
assert isinstance(db2.type, DisconnectedType)
@test_params
def test_rop(self, cls_ofg):
......
......@@ -1201,58 +1201,62 @@ def _populate_grad_dict(var_to_app_to_idx,
is_zero = _is_zero(term)
assert is_zero in ['yes', 'no', 'maybe']
if is_zero == 'maybe':
msg = "%s.grad returned %s of type %s for input"
msg += " %d. This input's only connections to "
msg += "the cost through this op are via "
msg += "integer-valued outputs so it should be "
msg += "NullType, DisconnectedType, or some form "
msg += "of zeros. It is not NullType or "
msg += "DisconnectedType and theano can't "
msg += "simplify it to a constant, so it's not "
msg += "verifiably zeros."
msg = msg % (str(node.op), str(term),
str(type(term)), i)
if is_zero == 'no':
msg = "%s.grad returned %s of type %s for input"
msg += " %d. Since this input is only connected "
msg += "to integer-valued outputs, it should "
msg += "evaluate to zeros, but it evaluates to"
msg += "%s."
msg % (node.op, term, type(term), i,
msg = "%s.grad returned %s of type %s for input" \
" %d. This input's only connections to " \
"the cost through this op are via " \
"integer-valued outputs so it should be " \
"NullType, DisconnectedType, or some form " \
"of zeros. It is not NullType or " \
"DisconnectedType and theano can't " \
"simplify it to a constant, so it's not " \
"verifiably zeros."
msg %= (node.op, term, type(term), i)
elif is_zero == 'no':
msg = "%s.grad returned %s of type %s for input" \
" %d. Since this input is only connected " \
"to integer-valued outputs, it should " \
"evaluate to zeros, but it evaluates to" \
"%s."
msg %= (node.op, term, type(term), i,
theano.get_scalar_constant_value(term))
raise ValueError(msg)
# Check that op.connection_pattern matches the connectivity
# logic driving the op.grad method
for i, packed in enumerate(zip(inputs, input_grads,
inputs_connected)):
ipt, ig, connected = packed
for i, (ipt, ig, connected) in enumerate(
zip(inputs, input_grads, inputs_connected)):
actually_connected = \
not isinstance(ig.type, DisconnectedType)
if isinstance(node.op, theano.OpFromGraph):
ov = node.op._grad_op_overrides_l[i]
if ov is not None:
connected &= not isinstance(
ov.type, DisconnectedType)
if actually_connected and not connected:
msg = "%s.grad returned %s of type %s for input %d."
msg += " Expected DisconnectedType instance based on "
msg += " the output of the op's connection_pattern "
msg += "method."
msg = msg % (str(node.op), str(ig), str(ig.type), i)
msg = "%s.grad returned %s of type %s for input %d." \
" Expected DisconnectedType instance based on " \
" the output of the op's connection_pattern " \
"method."
msg %= (str(node.op), str(ig), str(ig.type), i)
raise TypeError(msg)
if connected and not actually_connected:
msg = "%s.grad returned DisconnectedType for input"
msg += " %d."
elif connected and not actually_connected:
msg = "%s.grad returned DisconnectedType for input" \
" %d."
msg = msg % (str(node.op), i)
if hasattr(node.op, 'connection_pattern'):
msg += ' Its connection_pattern method does not'
msg += ' allow this.'
msg += ' Its connection_pattern method does not' \
' allow this.'
raise TypeError(msg)
else:
msg += ' You may want to implement a '
msg += 'connection_pattern method for it.'
msg += ' You may want to implement a ' \
'connection_pattern method for it.'
warnings.warn(msg)
# cache the result
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论