提交 35d78258 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

made zero check handle sparse types

上级 2c9aec23
......@@ -787,10 +787,39 @@ def _populate_grad_dict(var_to_node_to_idx,
# it's not undefined or disconnected
# The only other valid thing it can be is 0
is_zero = False
no_constant_value = True
try:
constant_value = tensor.get_constant_value(term)
except:
no_constant_value = False
except TypeError:
pass
extra_msg = ''
# The above won't work if it's a sparse type, handle sparse
# types here
if no_constant_value:
if isinstance(term.type, theano.sparse.SparseType):
if term.owner is not None and isinstance(term.owner.op,
theano.sparse.CSM):
data = term.owner.inputs[0]
try:
constant_value = tensor.get_constant_value(data)
no_constant_value = False
except TypeError:
print theano.printing.min_informative_str(data)
extra_msg += " It is a CSM, but its data isn't constant."
pass
else:
extra_msg += " It is a SparseType but theano doesn't know how"
extra_msg += " to turn it into a constant."
#end if CSM
else:
extra_msg += " It is not a SparseType."
#end if SparseType
#end if no_constant_value
if no_constant_value:
msg = "%s.grad returned %s of type %s for input"
msg += " %d. This input's only connections to "
msg += "the cost through this op are via "
......@@ -800,6 +829,7 @@ def _populate_grad_dict(var_to_node_to_idx,
msg += "DisconnectedType and theano can't "
msg += "simplify it to a constant, so it's not "
msg += "verifiably zeros."
msg += extra_msg
msg = msg % (str(node.op), str(term),
str(type(term)), i)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论