提交 ea66f4a7 作者: khaotik

OpFromGraph improvements

- now use explicit NullType() or DisconnectedType() instance for OpFromGraph special types. 'default' for no gradient/Rop override
- connection pattern now considers gradient override
- revert hackish changes done to theano/gradient.py
- bug fix for theano.gof.graph.io_connection_pattern with isolated output
- bug fix when _recompute_grad/rop_op is called twice for OfG instance
上级 b5685f95
差异被折叠。
......@@ -162,7 +162,7 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
w, b = T.vectors('wb')
# we make the 3rd gradient default (no override)
op_linear = cls_ofg([x, w, b], [x * w + b], grad_overrides=[go1, go2, Ellipsis])
op_linear = cls_ofg([x, w, b], [x * w + b], grad_overrides=[go1, go2, 'default'])
xx, ww, bb = T.vector('xx'), T.vector('yy'), T.vector('bb')
zz = T.sum(op_linear(xx, ww, bb))
dx, dw, db = T.grad(zz, [xx, ww, bb])
......@@ -176,7 +176,9 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
assert np.allclose(np.ones(16, dtype=config.floatX), dbv)
# NullType and DisconnectedType
op_linear2 = cls_ofg([x, w, b], [x * w + b], grad_overrides=[go1, None, 0])
op_linear2 = cls_ofg(
[x, w, b], [x * w + b],
grad_overrides=[go1, NullType()(), DisconnectedType()()])
zz2 = T.sum(op_linear2(xx, ww, bb))
dx2, dw2, db2 = T.grad(
zz2, [xx, ww, bb],
......@@ -205,7 +207,6 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
duval = np.random.rand(16).astype(config.floatX)
dvval = np.dot(duval, Wval)
dvval2 = fn(xval, Wval, duval)
print(dvval, dvval2)
assert np.allclose(dvval2, dvval)
@test_params
......
......@@ -1099,7 +1099,10 @@ def io_connection_pattern(inputs, outputs):
# connnection patterns of the individual outputs
global_connection_pattern = [[] for o in range(len(inputs))]
for out in outputs:
out_connection_pattern = connect_pattern_by_var[out]
out_connection_pattern = connect_pattern_by_var.get(out)
if out_connection_pattern is None:
# the output is completely isolated from inputs
out_connection_pattern = [False] * len(inputs)
for i in range(len(inputs)):
global_connection_pattern[i].append(out_connection_pattern[i])
......
......@@ -1233,12 +1233,6 @@ def _populate_grad_dict(var_to_app_to_idx,
actually_connected = \
not isinstance(ig.type, DisconnectedType)
if isinstance(node.op, theano.OpFromGraph):
ov = node.op._grad_op_overrides_l[i]
if ov is not None:
connected &= not isinstance(
ov.type, DisconnectedType)
if actually_connected and not connected:
msg = ("%s.grad returned %s of type %s for input %d."
" Expected DisconnectedType instance based on "
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论