提交 90aefc28 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

changed gradient.grad to not introduce integer-valued gradients

上级 b115fbc9
......@@ -444,7 +444,7 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, warn_type=False,
grad_dict = {}
# by default, the gradient of the cost is 1
if g_cost is None:
g_cost = tensor.ones_like(cost)
g_cost = _float_ones_like(cost)
grad_dict[cost] = g_cost
# the gradient of the constants is 0
......@@ -477,12 +477,18 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, warn_type=False,
if add_names:
cost_name = cost.name
# Make sure we didn't initialize the grad_dict with any ints
for var in grad_dict:
g = grad_dict[var]
if hasattr(g.type,'dtype'):
assert g.type.dtype.find('float') != -1
rval = _populate_grad_dict(var_to_node_to_idx,
grad_dict, wrt, cost_name)
for i in xrange(len(rval)):
if isinstance(rval[i].type, DisconnectedType):
rval[i] = wrt[i].zeros_like()
rval[i] = _float_zeros_like(wrt[i])
if using_tuple:
rval = tuple(rval)
......@@ -931,10 +937,31 @@ def grad_sources_inputs(sources, graph_inputs):
for key in grad_dict:
if isinstance(grad_dict[key].type, DisconnectedType):
if hasattr(key, 'zeros_like'):
grad_dict[key] = key.zeros_like()
grad_dict[key] = _float_zeros_like(key)
return grad_dict
def _float_zeros_like(x):
""" Like zeros_like, but forces the object to have a
a floating point dtype """
rval = x.zeros_like()
if rval.type.dtype.find('float') != -1:
return rval
return rval.astype(theano.config.floatX)
def _float_ones_like(x):
""" Like ones_like, but forces the object to have a
floating point dtype """
rval = tensor.ones_like(x)
if rval.type.dtype.find('float') != -1:
return rval
return rval.astype(theano.config.floatX)
class numeric_grad(object):
"""
......
......@@ -1910,6 +1910,7 @@ class TensorFromScalar(Op):
def grad(self, inp, grads):
s, = inp
dt, = grads
assert dt.type.dtype.find('float') != -1
return [scalar_from_tensor(dt)]
def __str__(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论