提交 9b81dfc5 authored 作者: Frederic Bastien's avatar Frederic Bastien

Reuse pre-created object

上级 a6773aaf
......@@ -146,6 +146,7 @@ class DisconnectedType(theano.gof.type.Type):
def __str__(self):
return 'DisconnectedType'
disconnected_type = DisconnectedType()
########################
......@@ -524,7 +525,7 @@ def grad(cost, wrt, consider_constant=None,
if elem not in var_to_app_to_idx and elem is not cost \
and elem not in grad_dict:
handle_disconnected(elem)
grad_dict[elem] = DisconnectedType()()
grad_dict[elem] = disconnected_type()
cost_name = None
if add_names and cost is not None:
......@@ -978,7 +979,7 @@ def _populate_grad_dict(var_to_app_to_idx,
# are disconnected
# (The op's grad method could do this too, but this saves the
# implementer the trouble of worrying about this case)
input_grads = [DisconnectedType()() for ipt in inputs]
input_grads = [disconnected_type() for ipt in inputs]
elif False not in only_connected_to_nan:
# All inputs are only connected to nan gradients, so we don't
# need to bother calling the grad method. We know the gradient
......@@ -988,7 +989,7 @@ def _populate_grad_dict(var_to_app_to_idx,
if connected:
input_grads.append(NullType()())
else:
input_grads.append(DisconnectedType()())
input_grads.append(disconnected_type())
else:
# At least one input of this op is connected to the cost so and
# not all output gradients are undefined so we must
......@@ -1124,7 +1125,7 @@ def _populate_grad_dict(var_to_app_to_idx,
raise TypeError(('%s.grad returned None for' +
' a gradient term, '
'this is prohibited. Instead of None,'
'return zeros_like(input), DisconnectedType()(),'
'return zeros_like(input), disconnected_type(),'
' or a NullType variable such as those made with '
'the grad_undefined or grad_unimplemented helper '
'functions.') % node.op)
......@@ -1258,14 +1259,14 @@ def _populate_grad_dict(var_to_app_to_idx,
# extraneous TensorConstant(0)
grad_dict[var] = reduce(lambda x, y: x + y, terms)
else:
grad_dict[var] = DisconnectedType()()
grad_dict[var] = disconnected_type()
if cost_name is not None and var.name is not None:
grad_dict[var].name = '(d%s/d%s)' % (cost_name, var.name)
else:
# this variable isn't connected to the cost in the
# computational graph
grad_dict[var] = DisconnectedType()()
grad_dict[var] = disconnected_type()
# end if cache miss
return grad_dict[var]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论