提交 6d516f5c authored 作者: Frederic Bastien's avatar Frederic Bastien

generate the new warning only when we had bad code.

上级 3a1ff68e
...@@ -1709,8 +1709,6 @@ def local_sum_sum(node): ...@@ -1709,8 +1709,6 @@ def local_sum_sum(node):
# do it all at once # do it all at once
return [T.Sum(None)(summed.owner.inputs[0])] return [T.Sum(None)(summed.owner.inputs[0])]
if theano.config.warn.sum_sum_bug:
print "WARNING: Theano version between version 9923a40c7b7a and the w august 2010(fixed date), generated an error in that case. This happen when their is 2 consecutive sum in the graph, bad code was generated. To disable this warning, set the theano flags warn.sum_sum_bug to False."
newaxis=list(tuple(summed.owner.op.axis)) newaxis=list(tuple(summed.owner.op.axis))
# figure out which dimensions of the original input are preserved # figure out which dimensions of the original input are preserved
for i in node.op.axis: for i in node.op.axis:
...@@ -1722,6 +1720,17 @@ def local_sum_sum(node): ...@@ -1722,6 +1720,17 @@ def local_sum_sum(node):
newaxis.append(new_i) newaxis.append(new_i)
assert len(newaxis)==len(list(summed.owner.op.axis)+list(node.op.axis)) assert len(newaxis)==len(list(summed.owner.op.axis)+list(node.op.axis))
#The old bugged logic. We keep it their to generate a warning when we generated bad code.
alldims = range(summed.owner.inputs[0].type.ndim)
alldims = [d for i,d in enumerate(alldims) if i in summed.owner.op.axis]
alldims = [d for i,d in enumerate(alldims) if i in node.op.axis]
newaxis_old = [i for i in xrange(summed.owner.inputs[0].type.ndim)
if i not in alldims]
if theano.config.warn.sum_sum_bug and newaxis!=newaxis_old:
print "WARNING: Theano version between version 9923a40c7b7a and the 2 august 2010(fixation date), generated an error in that case. This happen when their is 2 consecutive sum in the graph, bad code was generated. To disable this warning, set the theano flags warn.sum_sum_bug to False."
combined_sum = T.Sum(newaxis) combined_sum = T.Sum(newaxis)
return [combined_sum(summed.owner.inputs[0])] return [combined_sum(summed.owner.inputs[0])]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论