Commit dabbc4e2 authored by Frederic

Cache Scan.connection_pattern(). Hopefully, it will speed up theano.grad.

Parent 116f70fe
......@@ -1285,6 +1285,11 @@ class Scan(PureOp):
        return ipos + opos
    def connection_pattern(self, node):
        # We cache this, as grad calls connection_pattern, which in turn
        # calls grad. There was a case where theano.grad() took 4h on a
        # graph with many scans nested inside one another.
        if hasattr(node.tag, 'connection_pattern'):
            return node.tag.connection_pattern
        # The gradient with respect to n_steps is disconnected
        connection_pattern = [[False for output in node.outputs]]
        connection_pattern += [[False for output in node.outputs]
......@@ -1391,6 +1396,8 @@ class Scan(PureOp):
            for k in xrange(len(connection_pattern)):
                if connection_pattern[k][jidx]:
                    connection_pattern[k][iidx] = True
        node.tag.connection_pattern = connection_pattern
        return connection_pattern
    ### GRAD FUNCTION
......
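The caching idea in this change can be tried in isolation. The sketch below is not Theano code: `Tag`, `Node`, and `ToyOp` are hypothetical stand-ins, and the "slow path" is a placeholder for the real connection analysis. It only illustrates storing the result on `node.tag` and returning it on later calls.

```python
class Tag(object):
    """Hypothetical stand-in for the scratch namespace exposed as node.tag."""


class Node(object):
    """Hypothetical stand-in for an apply node: inputs, outputs, and a tag."""

    def __init__(self, n_inputs, n_outputs):
        self.inputs = list(range(n_inputs))
        self.outputs = list(range(n_outputs))
        self.tag = Tag()


class ToyOp(object):
    """Toy op whose connection_pattern result is cached per node."""

    def __init__(self):
        self.n_computed = 0  # counts how often the slow path runs

    def connection_pattern(self, node):
        # Fast path: reuse the pattern stored by an earlier call.
        if hasattr(node.tag, 'connection_pattern'):
            return node.tag.connection_pattern
        # Slow path (placeholder logic): first input disconnected,
        # every other input connected to every output.
        self.n_computed += 1
        pattern = [[False for _ in node.outputs]]
        pattern += [[True for _ in node.outputs] for _ in node.inputs[1:]]
        node.tag.connection_pattern = pattern
        return pattern


op = ToyOp()
node = Node(n_inputs=3, n_outputs=2)
first = op.connection_pattern(node)
second = op.connection_pattern(node)  # served from node.tag
```

Two properties of this pattern are worth keeping in mind: the cached list is returned by reference, so a caller that mutates it changes what later calls see, and the cache is never invalidated, so it assumes the node's inputs and outputs do not change after the first call.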