提交 417067e1 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

function required to have fast gradient computations

上级 05e72586
......@@ -930,3 +930,34 @@ class scan_args(object):
'mit_sot_in_slices')):
getattr(res, attr).extend(getattr(other, attr))
return res
def forced_replace(out, x, y):
"""
:param out: Theano Variable
:param x: Theano Variable
:param y: Theano Variable
This function checks all internal values of the graph that computes the
variable ``out`` for occurances of values identical with ``x``. If such
occurances are encountered then they are replaced with variable ``y``.
For example:
out := sigmoid(wu)*(1-sigmoid(wu))
x := sigmoid(wu)
forced_replace(out, x, y) := y*(1-y)
"""
if out is None:
return None
def traverse(graph, x):
if equal_computations([graph], [x]):
return [graph]
elif not graph.owner:
return []
else:
rval = []
for inp in graph.owner.inputs:
rval += traverse(inp, x)
return rval
to_replace = traverse(out, x)
return clone(out, replace=dict((v, y) for v in to_replace))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论