提交 7c4d91b4 authored 作者: James Bergstra's avatar James Bergstra

ENH: BadGradFeature

上级 26dddb57
......@@ -194,6 +194,7 @@ def grad_sources_inputs(sources, graph_inputs, warn_type=True):
gmap[r] = g_r
return gmap
class BadGradOp(gof.Op):
"""
An Op representing a gradient that cannot be computed.
......@@ -239,6 +240,7 @@ class BadGradOp(gof.Op):
def raise_exc(self):
    """Raise the stored exception class with the stored message.

    ``self.exc`` and ``self.msg`` are read from the instance; they are
    presumably set by ``BadGradOp.__init__``, which is outside this
    view -- confirm against the full class definition.
    """
    raise self.exc(self.msg)
class GradNotImplementedOp(BadGradOp):
""" A BadGradOp representing a gradient that hasn't been implemented yet.
"""
......@@ -261,6 +263,7 @@ class GradNotImplementedOp(BadGradOp):
"%s does not implement its gradient with respect to input %d" \
% (str(type(op)), x_pos))
def grad_not_implemented(op, x_pos, x):
"""
Return an un-computable symbolic variable of type `x.type`.
......@@ -274,59 +277,20 @@ def grad_not_implemented(op, x_pos, x):
return GradNotImplementedOp(op, x_pos)(x)
def check_for_bad_grad( variables ):
"""
variables: A gof.Variable or list thereof
Raises an exception if any of the variables represents
an expression involving a BadGradOp
"""
#implemented using a deque rather than recursion because python recursion
#limit is set low by default
#handle the case where var is a theano.compile.io.SymbolicOutput
if hasattr(variables,'variable'):
variables = [ variables.variable ]
if not (isinstance(variables, list) or \
isinstance(variables, gof.Variable)):
raise TypeError("Expected gof.Variable or list thereof, got "+\
str(type(variables)))
if not isinstance(variables,list):
variables = [ variables ]
vars_to_check = deque(variables)
already_checked = set([])
while True:
try:
var = vars_to_check.pop()
except IndexError:
break
if var not in already_checked:
already_checked.update([var])
#handle the case where var is a theano.compile.io.SymbolicOutput
if hasattr(var, 'variable'):
var = var.variable
def raise_if_bad_grad(node):
    """Raise if `node` is an application of a BadGradOp.

    `node` is an apply node (or None, e.g. for graph inputs that have
    no owner); a BadGradOp marks a gradient that cannot be computed,
    and its ``raise_exc`` raises the stored exception.
    """
    if node is not None:
        op = node.op
        if isinstance(op, BadGradOp):
            op.raise_exc()
        # NOTE(review): `vars_to_check` is not defined in this function;
        # it belongs to the deque-based traversal in the surrounding
        # (removed) check_for_bad_grad. This line appears to be diff
        # residue rather than part of the new module-level helper --
        # confirm against the applied revision.
        vars_to_check.extendleft(node.inputs)
if not isinstance(var, gof.Variable):
raise TypeError("Expected gof.Variable, got "+str(type(var)))
node = var.owner
class BadGradFeature(gof.Feature):
    """Graph feature that rejects uncomputable gradients on import.

    Whenever a node is imported into a function graph, raise (via
    ``raise_if_bad_grad``) if that node's op is a BadGradOp.
    """

    def on_import(self, fgraph, node):
        # `fgraph` is required by the Feature callback signature but is
        # unused here; the check only needs the node itself.
        raise_if_bad_grad(node)
if node is not None:
op = node.op
if isinstance(op, BadGradOp):
op.raise_exc()
vars_to_check.extendleft(node.inputs)
#end if node is not None
#end if not already_checked
#end while
theano.compile.function_module.std_fgraph.features.append(BadGradFeature)
########################
......@@ -648,7 +612,12 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, warn_type=False,
and ret[-1].name is None:
ret[-1].name = '(d%s/d%s)' % (cost.name, p.name)
check_for_bad_grad(ret)
# new_vars is meant to be a list of all variables created
# by this call to grad(), which will be visible to the caller
# after we return.
new_vars = graph.ancestors(ret,
blockers=graph.ancestors(cost) + list(wrt))
map(raise_if_bad_grad, [v.owner for v in new_vars])
return format_as(using_list, using_tuple, ret)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论