提交 8e8758a3 authored 作者: khaotik's avatar khaotik

better handling for NullType and DisconnectedType

major changes: - The self._grad_op now only returns zeros_like() for special types like NullType() or DisconnectedType() - call to grad() will further replace returned zero tensors with special types - proposed gradient override interface : (single or list of below) Ellipsis -> <no_override> (-) since python 2 does not support `[...]` syntax, this may result in uglier code in python 2 None -> NullType() int(0) -> DisconnectedType() OpFromGraph instance or callable -> <override> minor changes: - various typo/bug fixes notes: - This commit breaks OpFromGraph.R_op, which is expected to be fixed in upcoming commits.
上级 a6e5cd74
差异被折叠。
...@@ -160,7 +160,7 @@ class T_OpFromGraph(unittest_tools.InferShapeTester): ...@@ -160,7 +160,7 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
w, b = T.vectors('wb') w, b = T.vectors('wb')
# we make the 3rd gradient default (no override) # we make the 3rd gradient default (no override)
op_linear = cls_ofg([x, w, b], [x * w + b], grad_overrides=[go1, go2]) op_linear = cls_ofg([x, w, b], [x * w + b], grad_overrides=[go1, go2, Ellipsis])
xx, ww, bb = T.vector('xx'), T.vector('yy'), T.vector('bb') xx, ww, bb = T.vector('xx'), T.vector('yy'), T.vector('bb')
zz = T.sum(op_linear(xx, ww, bb)) zz = T.sum(op_linear(xx, ww, bb))
dx, dw, db = T.grad(zz, [xx, ww, bb]) dx, dw, db = T.grad(zz, [xx, ww, bb])
...@@ -281,21 +281,19 @@ class T_OpFromGraph(unittest_tools.InferShapeTester): ...@@ -281,21 +281,19 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
[True, False, True]] [True, False, True]]
assert results == expect_result assert results == expect_result
@test_params def test_infer_shape(self):
def test_infer_shape(self, cls_ofg): # infer_shape does not need to be tested against the inline case
# since the Op is removed during the optimization phase
x = T.matrix('x') x = T.matrix('x')
y = T.matrix('y') y = T.matrix('y')
o1 = x + y o1 = x + y
o2 = x * y o2 = x * y
op_graph = cls_ofg([x, y], [o1, o2]) op_graph = OpFromGraph([x, y], [o1, o2])
q = T.matrix('q') q = T.matrix('q')
p = T.matrix('p') p = T.matrix('p')
# we don't want check_topo for inline ops
# since the inline op is replaced during optimization
self._compile_and_check([q, p], self._compile_and_check([q, p],
op_graph(q, p), op_graph(q, p),
[np.ones([3, 4], dtype=config.floatX), [np.ones([3, 4], dtype=config.floatX),
np.ones([3, 4], dtype=config.floatX)], np.ones([3, 4], dtype=config.floatX)],
cls_ofg, OpFromGraph)
check_topo=not op_graph.is_inline)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论