提交 08a58f7b authored 作者: Frederic's avatar Frederic

make theano.get_constant_value() to remove theano/gradient.py dependance on the sparse packages.

上级 42b8b66e
...@@ -161,3 +161,22 @@ def dot(l, r): ...@@ -161,3 +161,22 @@ def dot(l, r):
raise NotImplementedError("Dot failed for the following reasons:", raise NotImplementedError("Dot failed for the following reasons:",
(e0, e1)) (e0, e1))
return rval return rval
def get_constant_value(v):
"""return the constant scalar(0-D) value underlying variable `v`
If v is the output of dimshuffles, fills, allocs, rebroadcasts, cast
this function digs through them.
If theano.sparse is also there, we will look over CSM op.
If `v` is not some view of constant data, then raise a TypeError.
"""
if hasattr(theano, 'sparse') and isinstance(v.type,
theano.sparse.SparseType):
if v.owner is not None and isinstance(v.owner.op,
theano.sparse.CSM):
data = v.owner.inputs[0]
return tensor.get_constant_value(data)
return tensor.get_constant_value(v)
...@@ -804,36 +804,11 @@ def _populate_grad_dict(var_to_node_to_idx, ...@@ -804,36 +804,11 @@ def _populate_grad_dict(var_to_node_to_idx,
no_constant_value = True no_constant_value = True
try: try:
constant_value = tensor.get_constant_value(term) constant_value = theano.get_constant_value(term)
no_constant_value = False no_constant_value = False
except TypeError: except TypeError:
pass pass
extra_msg = ''
# The above won't work if it's a sparse type, handle sparse
# types here
if no_constant_value:
if isinstance(term.type, theano.sparse.SparseType):
if term.owner is not None and isinstance(term.owner.op,
theano.sparse.CSM):
data = term.owner.inputs[0]
try:
constant_value = tensor.get_constant_value(data)
no_constant_value = False
except TypeError:
print theano.printing.min_informative_str(data)
extra_msg += " It is a CSM, but its data isn't constant."
pass
else:
extra_msg += " It is a SparseType but theano doesn't know how"
extra_msg += " to turn it into a constant."
#end if CSM
else:
extra_msg += " It is not a SparseType."
#end if SparseType
#end if no_constant_value
if no_constant_value: if no_constant_value:
msg = "%s.grad returned %s of type %s for input" msg = "%s.grad returned %s of type %s for input"
msg += " %d. This input's only connections to " msg += " %d. This input's only connections to "
...@@ -844,7 +819,6 @@ def _populate_grad_dict(var_to_node_to_idx, ...@@ -844,7 +819,6 @@ def _populate_grad_dict(var_to_node_to_idx,
msg += "DisconnectedType and theano can't " msg += "DisconnectedType and theano can't "
msg += "simplify it to a constant, so it's not " msg += "simplify it to a constant, so it's not "
msg += "verifiably zeros." msg += "verifiably zeros."
msg += extra_msg
msg = msg % (str(node.op), str(term), msg = msg % (str(node.op), str(term),
str(type(term)), i) str(type(term)), i)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论