提交 5a523b34 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

fixed check_for_bad_grad to support SymbolicOutput

上级 a1be5079
...@@ -285,6 +285,10 @@ def check_for_bad_grad( variables ): ...@@ -285,6 +285,10 @@ def check_for_bad_grad( variables ):
#implemented using a deque rather than recursion because python recursion #implemented using a deque rather than recursion because python recursion
#limit is set low by default #limit is set low by default
#handle the case where var is a theano.compile.io.SymbolicOutput
if hasattr(variables,'variable'):
variables = [ variables.variable ]
if not (isinstance(variables, list) or \ if not (isinstance(variables, list) or \
isinstance(variables, gof.Variable)): isinstance(variables, gof.Variable)):
raise TypeError("Expected gof.Variable or list thereof, got "+\ raise TypeError("Expected gof.Variable or list thereof, got "+\
...@@ -306,7 +310,12 @@ def check_for_bad_grad( variables ): ...@@ -306,7 +310,12 @@ def check_for_bad_grad( variables ):
if var not in already_checked: if var not in already_checked:
already_checked.update([var]) already_checked.update([var])
assert isinstance(var, gof.Variable) #handle the case where var is a theano.compile.io.SymbolicOutput
if hasattr(var, 'variable'):
var = var.variable
if not isinstance(var, gof.Variable):
raise TypeError("Expected gof.Variable, got "+str(type(var)))
node = var.owner node = var.owner
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论