提交 dafaffaa authored 作者: Razvan Pascanu's avatar Razvan Pascanu

fix connection pattern

上级 703d535a
...@@ -1344,8 +1344,8 @@ class Scan(PureOp): ...@@ -1344,8 +1344,8 @@ class Scan(PureOp):
j_inp_idx = self.get_input_pos(jidx) + 1 j_inp_idx = self.get_input_pos(jidx) + 1
if connection_pattern[j_inp_idx][iidx] == True: if connection_pattern[j_inp_idx][iidx] == True:
for k in xrange(len(connection_pattern)): for k in xrange(len(connection_pattern)):
if connection_pattern[k][iidx]: if connection_pattern[k][jidx]:
connection_pattern[k][jidx] = True connection_pattern[k][iidx] = True
return connection_pattern return connection_pattern
### GRAD FUNCTION ### GRAD FUNCTION
......
...@@ -3310,6 +3310,18 @@ class T_Scan(unittest.TestCase): ...@@ -3310,6 +3310,18 @@ class T_Scan(unittest.TestCase):
# disconnected gradient a non disconnected type was returned # disconnected gradient a non disconnected type was returned
tensor.grad((m * m2).sum(), v) tensor.grad((m * m2).sum(), v)
def test_disconnected_gradient2(self):
    """Regression test for scan's connection_pattern.

    `m2` depends on the matrix `m` only indirectly, through the
    recurrent output `u` (the step returns ``[x + u, u + v]`` where
    ``x`` is a row of ``m``).  Computing ``grad(m2.sum(), m)`` therefore
    requires connection_pattern to propagate connectivity transitively;
    the test passes as long as no exception is raised.
    """
    v = tensor.vector('v')
    m = tensor.matrix('m')
    # initial state for the recurrent output `u`
    u0 = tensor.zeros((7,))
    [u, m2], _ = theano.scan(lambda x, u: [x+u, u+v],
                             sequences=m,
                             outputs_info=[u0, None])
    # This used to raise an exception with older versions because
    # scan could not detect the connection between `m2` and `x`
    tensor.grad(m2.sum(), m)
def test_pregreedy_optimizer(self): def test_pregreedy_optimizer(self):
W = tensor.zeros((5, 4)) W = tensor.zeros((5, 4))
bv = tensor.zeros((5,)) bv = tensor.zeros((5,))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论