提交 fb3b6ee5 authored 作者: Frederic Bastien's avatar Frederic Bastien

Add a test

上级 1396175a
...@@ -266,6 +266,38 @@ class T_OpFromGraph(unittest_tools.InferShapeTester): ...@@ -266,6 +266,38 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
# TODO list override case # TODO list override case
@test_params
def test_connection_pattern_override(self, cls_ofg):
x, y = T.vectors('xy')
def f1(x, y):
del x
# but we know how to backpropagate for x for some reasons
# and we don't care about the gradient wrt y.
return y + T.round(y)
def f1_back(inputs, output_gradients):
return [
output_gradients[0],
theano.gradient.disconnected_type()]
op = cls_ofg(
inputs=[x, y],
outputs=[f1(x, y)],
grad_overrides=f1_back,
connection_pattern=[[True], [False]], # This is new
on_unused_input='ignore') # This is new
c = op(x, y)
g1 = theano.grad(c.sum(), x)
out = g1.eval({
x: np.ones((5,), dtype=np.float32),
y: np.ones((5,), dtype=np.float32)})
assert np.allclose(out, [1.] *5)
# array([ 1., 1., 1., 1., 1.])
@test_params @test_params
def test_nested(self, cls_ofg): def test_nested(self, cls_ofg):
x, y = T.vectors('xy') x, y = T.vectors('xy')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论