Commit 9b81dfc5 authored by Frederic Bastien

Reuse pre-created object

Parent a6773aaf
@@ -146,6 +146,7 @@ class DisconnectedType(theano.gof.type.Type):
     def __str__(self):
         return 'DisconnectedType'
+disconnected_type = DisconnectedType()
 ########################
@@ -524,7 +525,7 @@ def grad(cost, wrt, consider_constant=None,
             if elem not in var_to_app_to_idx and elem is not cost \
                     and elem not in grad_dict:
                 handle_disconnected(elem)
-                grad_dict[elem] = DisconnectedType()()
+                grad_dict[elem] = disconnected_type()
     cost_name = None
     if add_names and cost is not None:
@@ -978,7 +979,7 @@ def _populate_grad_dict(var_to_app_to_idx,
             # are disconnected
             # (The op's grad method could do this too, but this saves the
             # implementer the trouble of worrying about this case)
-            input_grads = [DisconnectedType()() for ipt in inputs]
+            input_grads = [disconnected_type() for ipt in inputs]
         elif False not in only_connected_to_nan:
             # All inputs are only connected to nan gradients, so we don't
             # need to bother calling the grad method. We know the gradient
@@ -988,7 +989,7 @@ def _populate_grad_dict(var_to_app_to_idx,
                 if connected:
                     input_grads.append(NullType()())
                 else:
-                    input_grads.append(DisconnectedType()())
+                    input_grads.append(disconnected_type())
         else:
             # At least one input of this op is connected to the cost so and
             # not all output gradients are undefined so we must
@@ -1124,7 +1125,7 @@ def _populate_grad_dict(var_to_app_to_idx,
                 raise TypeError(('%s.grad returned None for' +
                                  ' a gradient term, '
                                  'this is prohibited. Instead of None,'
-                                 'return zeros_like(input), DisconnectedType()(),'
+                                 'return zeros_like(input), disconnected_type(),'
                                  ' or a NullType variable such as those made with '
                                  'the grad_undefined or grad_unimplemented helper '
                                  'functions.') % node.op)
@@ -1258,14 +1259,14 @@ def _populate_grad_dict(var_to_app_to_idx,
                 # extraneous TensorConstant(0)
                 grad_dict[var] = reduce(lambda x, y: x + y, terms)
             else:
-                grad_dict[var] = DisconnectedType()()
+                grad_dict[var] = disconnected_type()
             if cost_name is not None and var.name is not None:
                 grad_dict[var].name = '(d%s/d%s)' % (cost_name, var.name)
         else:
             # this variable isn't connected to the cost in the
             # computational graph
-            grad_dict[var] = DisconnectedType()()
+            grad_dict[var] = disconnected_type()
         # end if cache miss
         return grad_dict[var]
...
Markdown formatting supported
0%
You are adding 0 people to this discussion. Please proceed carefully.
Please finish editing this comment first!
Register or sign in to post a comment