提交 4af5f8ad authored 作者: khaotik's avatar khaotik

make conn pat check for grad() safe for OfG

plus some minor code speedup in gradient.py
上级 88626f5d
...@@ -181,12 +181,12 @@ class T_OpFromGraph(unittest_tools.InferShapeTester): ...@@ -181,12 +181,12 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
dx2, dw2, db2 = T.grad( dx2, dw2, db2 = T.grad(
zz2, [xx, ww, bb], zz2, [xx, ww, bb],
return_disconnected='Disconnected', return_disconnected='Disconnected',
disconnected_inputs='ignore',
null_gradients='return') null_gradients='return')
fn2 = function([xx, ww, bb], [dx2, dw2, db2]) assert isinstance(dx2.type, T.TensorType)
dxv2, dwv2, dbv2 = fn2(xv, wv, bv) assert dx2.ndim == 1
assert numpy.allclose(wv * 2, dxv) assert isinstance(dw2.type, NullType)
assert isinstance(dwv2.type, NullType) assert isinstance(db2.type, DisconnectedType)
assert isinstance(dbv2.type, DisconnectedType)
@test_params @test_params
def test_rop(self, cls_ofg): def test_rop(self, cls_ofg):
......
...@@ -1201,58 +1201,62 @@ def _populate_grad_dict(var_to_app_to_idx, ...@@ -1201,58 +1201,62 @@ def _populate_grad_dict(var_to_app_to_idx,
is_zero = _is_zero(term) is_zero = _is_zero(term)
assert is_zero in ['yes', 'no', 'maybe'] assert is_zero in ['yes', 'no', 'maybe']
if is_zero == 'maybe': if is_zero == 'maybe':
msg = "%s.grad returned %s of type %s for input" msg = "%s.grad returned %s of type %s for input" \
msg += " %d. This input's only connections to " " %d. This input's only connections to " \
msg += "the cost through this op are via " "the cost through this op are via " \
msg += "integer-valued outputs so it should be " "integer-valued outputs so it should be " \
msg += "NullType, DisconnectedType, or some form " "NullType, DisconnectedType, or some form " \
msg += "of zeros. It is not NullType or " "of zeros. It is not NullType or " \
msg += "DisconnectedType and theano can't " "DisconnectedType and theano can't " \
msg += "simplify it to a constant, so it's not " "simplify it to a constant, so it's not " \
msg += "verifiably zeros." "verifiably zeros."
msg = msg % (str(node.op), str(term), msg %= (node.op, term, type(term), i)
str(type(term)), i)
elif is_zero == 'no':
if is_zero == 'no': msg = "%s.grad returned %s of type %s for input" \
msg = "%s.grad returned %s of type %s for input" " %d. Since this input is only connected " \
msg += " %d. Since this input is only connected " "to integer-valued outputs, it should " \
msg += "to integer-valued outputs, it should " "evaluate to zeros, but it evaluates to" \
msg += "evaluate to zeros, but it evaluates to" "%s."
msg += "%s."
msg %= (node.op, term, type(term), i,
msg % (node.op, term, type(term), i,
theano.get_scalar_constant_value(term)) theano.get_scalar_constant_value(term))
raise ValueError(msg) raise ValueError(msg)
# Check that op.connection_pattern matches the connectivity # Check that op.connection_pattern matches the connectivity
# logic driving the op.grad method # logic driving the op.grad method
for i, packed in enumerate(zip(inputs, input_grads, for i, (ipt, ig, connected) in enumerate(
inputs_connected)): zip(inputs, input_grads, inputs_connected)):
ipt, ig, connected = packed
actually_connected = \ actually_connected = \
not isinstance(ig.type, DisconnectedType) not isinstance(ig.type, DisconnectedType)
if isinstance(node.op, theano.OpFromGraph):
ov = node.op._grad_op_overrides_l[i]
if ov is not None:
connected &= not isinstance(
ov.type, DisconnectedType)
if actually_connected and not connected: if actually_connected and not connected:
msg = "%s.grad returned %s of type %s for input %d." msg = "%s.grad returned %s of type %s for input %d." \
msg += " Expected DisconnectedType instance based on " " Expected DisconnectedType instance based on " \
msg += " the output of the op's connection_pattern " " the output of the op's connection_pattern " \
msg += "method." "method."
msg = msg % (str(node.op), str(ig), str(ig.type), i) msg %= (str(node.op), str(ig), str(ig.type), i)
raise TypeError(msg) raise TypeError(msg)
if connected and not actually_connected: elif connected and not actually_connected:
msg = "%s.grad returned DisconnectedType for input" msg = "%s.grad returned DisconnectedType for input" \
msg += " %d." " %d."
msg = msg % (str(node.op), i) msg = msg % (str(node.op), i)
if hasattr(node.op, 'connection_pattern'): if hasattr(node.op, 'connection_pattern'):
msg += ' Its connection_pattern method does not' msg += ' Its connection_pattern method does not' \
msg += ' allow this.' ' allow this.'
raise TypeError(msg) raise TypeError(msg)
else: else:
msg += ' You may want to implement a ' msg += ' You may want to implement a ' \
msg += 'connection_pattern method for it.' 'connection_pattern method for it.'
warnings.warn(msg) warnings.warn(msg)
# cache the result # cache the result
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论