提交 3a1ff68e authored 作者: Frederic Bastien's avatar Frederic Bastien

fix a bug in optimization of two consecutive sum. Add a warning when we hit this…

fix a bug in optimization of two consecutive sum. Add a warning when we hit this case. The warning can be disabled by the theano flags warn.sum_sum_bug=False.
上级 ebc8af30
...@@ -87,3 +87,7 @@ AddConfigVar('warn.argmax_pushdown_bug', ...@@ -87,3 +87,7 @@ AddConfigVar('warn.argmax_pushdown_bug',
AddConfigVar('warn.gpusum_01_011_0111_bug', AddConfigVar('warn.gpusum_01_011_0111_bug',
"Warn if we are in a case where old version of Theano had a silent bug with GpuSum pattern 01,011 and 0111 when the first dimensions was bigger then 4096. Was fixed 31 may 2010", "Warn if we are in a case where old version of Theano had a silent bug with GpuSum pattern 01,011 and 0111 when the first dimensions was bigger then 4096. Was fixed 31 may 2010",
BoolParam(True)) BoolParam(True))
AddConfigVar('warn.sum_sum_bug',
"Warn if we are in a case where Theano version between version 9923a40c7b7a and the 2 august 2010(fixed date), generated an error in that case. This happen when their is 2 consecutive sum in the graph, bad code was generated. Was fixed 2 August 2010",
BoolParam(True))
...@@ -1709,19 +1709,19 @@ def local_sum_sum(node): ...@@ -1709,19 +1709,19 @@ def local_sum_sum(node):
# do it all at once # do it all at once
return [T.Sum(None)(summed.owner.inputs[0])] return [T.Sum(None)(summed.owner.inputs[0])]
if theano.config.warn.sum_sum_bug:
print "WARNING: Theano version between version 9923a40c7b7a and the w august 2010(fixed date), generated an error in that case. This happen when their is 2 consecutive sum in the graph, bad code was generated. To disable this warning, set the theano flags warn.sum_sum_bug to False."
newaxis=list(tuple(summed.owner.op.axis))
# figure out which dimensions of the original input are preserved # figure out which dimensions of the original input are preserved
alldims = range(summed.owner.inputs[0].type.ndim) for i in node.op.axis:
new_i = i
# trim out the dimensions that were removed by the first sum for ii in summed.owner.op.axis:
alldims = [d for i,d in enumerate(alldims) if i in summed.owner.op.axis] if i>=ii:
new_i+=1
# trim out the dimensions removed by second sum assert new_i not in newaxis
alldims = [d for i,d in enumerate(alldims) if i in node.op.axis] newaxis.append(new_i)
# figure out an axis argument that combines the effect of both assert len(newaxis)==len(list(summed.owner.op.axis)+list(node.op.axis))
newaxis = [i for i in xrange(summed.owner.inputs[0].type.ndim)
if i not in alldims]
combined_sum = T.Sum(newaxis) combined_sum = T.Sum(newaxis)
return [combined_sum(summed.owner.inputs[0])] return [combined_sum(summed.owner.inputs[0])]
......
...@@ -1531,6 +1531,48 @@ def test_constant_get_stabilized(): ...@@ -1531,6 +1531,48 @@ def test_constant_get_stabilized():
#When this error is fixed, the following line should be ok. #When this error is fixed, the following line should be ok.
assert f()==800,f() assert f()==800,f()
class T_local_sum_canonicalize(unittest.TestCase):
def setUp(self):
self.mode = theano.compile.get_default_mode().including('canonicalize')
def test_local_sum_all_to_none(self):
a = T.tensor3()
input=numpy.arange(3*3*3).reshape(3,3,3)
f = theano.function([a],a.sum())
assert len(f.maker.env.nodes)==1
assert numpy.allclose(f(input),input.sum())
f = theano.function([a],a.sum([0,1,2]))
assert len(f.maker.env.nodes)==1
assert numpy.allclose(f(input),input.sum())
f = theano.function([a],a.sum(0).sum(0).sum(0))
assert len(f.maker.env.nodes)==1
assert numpy.allclose(f(input),input.sum())
def test_local_sum_sum(self):
a=T.tensor3()
input=numpy.arange(3*3*3).reshape(3,3,3)
dims=[(0,0),(1,0),(2,0),(0,1),(1,1),(2,1)]
for d,dd in dims:
f = theano.function([a],a.sum(d).sum(dd))
assert numpy.allclose(f(input),input.sum(d).sum(dd))
assert len(f.maker.env.nodes)==1
for d,dd in dims:
f = theano.function([a],a.sum(d).sum(dd).sum(0))
assert numpy.allclose(f(input),input.sum(d).sum(dd).sum(0))
assert len(f.maker.env.nodes)==1
for d in [0,1,2]:
f = theano.function([a],a.sum(d).sum(None))
assert numpy.allclose(f(input),input.sum(d).sum())
assert len(f.maker.env.nodes)==1
for d in [0,1,2]:
f = theano.function([a],a.sum(None).sum())
assert numpy.allclose(f(input),input.sum())
assert len(f.maker.env.nodes)==1
class T_local_sum_dimshuffle(unittest.TestCase): class T_local_sum_dimshuffle(unittest.TestCase):
def setUp(self): def setUp(self):
self.mode = theano.compile.get_default_mode().including('canonicalize') self.mode = theano.compile.get_default_mode().including('canonicalize')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论