提交 f85d7e34 authored 作者: Frederic's avatar Frederic

rename opt local_sum_broadcast to local_reduce_broadcast, update test and fix crash.

This new opt was crashing with the tensor.min()/tensor.max() CAReduce. I fixed this. I tested the opt with many more reductions than just sum.
上级 e0978aa9
...@@ -3176,7 +3176,12 @@ def local_reduce_broadcastable(node): ...@@ -3176,7 +3176,12 @@ def local_reduce_broadcastable(node):
ii += 1 ii += 1
new_reduced = reduced.dimshuffle(*pattern) new_reduced = reduced.dimshuffle(*pattern)
if new_axis: if new_axis:
new_op = node.op.__class__(axis=new_axis) if type(node.op) == theano.tensor.elemwise.CAReduce:
# This happen for tensor.max(), tensor.min()
new_op = node.op.__class__(node.op.scalar_op,
axis=new_axis)
else:
new_op = node.op.__class__(axis=new_axis)
return [new_op(new_reduced)] return [new_op(new_reduced)]
else: else:
# -- in this case we can remove the reduction completely # -- in this case we can remove the reduction completely
......
...@@ -3313,54 +3313,59 @@ class T_local_sum(unittest.TestCase): ...@@ -3313,54 +3313,59 @@ class T_local_sum(unittest.TestCase):
finally: finally:
config.on_opt_error = backup config.on_opt_error = backup
def test_local_reduce_broadcast_all_0(self):
optimizer = optdb.query(self.mode._optimizer)
x = T.TensorType('int64', (True, True, True))() class T_local_reduce(unittest.TestCase):
g = FunctionGraph([x], [x.sum()]) def setUp(self):
optimizer.optimize(g) self.mode = theano.compile.get_default_mode().including('canonicalize',
assert not any([ 'specialize')
isinstance(node.op, T.CAReduce)
for node in g.toposort()])
def test_local_reduce_broadcast_all_1(self): def test_local_reduce_broadcast_all_0(self):
optimizer = optdb.query(self.mode._optimizer) for fct in [tensor.sum, tensor.all, tensor.any, tensor.prod,
tensor.max, tensor.min]:
x = T.TensorType('int64', (True, True, True))()
f = theano.function([x], [fct(x)], mode=self.mode)
assert not any([
isinstance(node.op, T.CAReduce)
for node in f.maker.fgraph.toposort()])
x = T.TensorType('int64', (True, True))() def test_local_reduce_broadcast_all_1(self):
g = FunctionGraph([x], [x.sum(axis=[0, 1])]) for fct in [tensor.sum, tensor.all, tensor.any, tensor.prod,
optimizer.optimize(g) tensor.max, tensor.min]:
assert not any([ x = T.TensorType('int64', (True, True))()
isinstance(node.op, T.CAReduce) f = theano.function([x], [fct(x, axis=[0, 1])], mode=self.mode)
for node in g.toposort()]) assert not any([
isinstance(node.op, T.CAReduce)
for node in f.maker.fgraph.toposort()])
def test_local_reduce_broadcast_some_0(self): def test_local_reduce_broadcast_some_0(self):
optimizer = optdb.query(self.mode._optimizer) for fct in [tensor.sum, tensor.all, tensor.any, tensor.prod,
tensor.max, tensor.min]:
x = T.TensorType('int64', (True, False, True))() x = T.TensorType('int64', (True, False, True))()
g = FunctionGraph([x], [x.sum(axis=[0, 1])]) f = theano.function([x], [fct(x, axis=[0, 1])], mode=self.mode)
optimizer.optimize(g)
order = g.toposort() order = f.maker.fgraph.toposort()
assert 1 == sum([isinstance(node.op, T.CAReduce) for node in order]) assert 1 == sum([isinstance(node.op, T.CAReduce)
if config.mode == 'FAST_COMPILE': for node in order])
node = order[-1]
else: node = [node for node in order if isinstance(node.op,
node = order[-2] tensor.CAReduce)][0]
op = node.op
assert isinstance(op, T.CAReduce) op = node.op
# -- the leading broadcastable dimension has been dropped assert isinstance(op, T.CAReduce)
# by the local_reduce_broadcastable optimization # -- the leading broadcastable dimension has been dropped
# now summation is over the original x's dimension 1. # by the local_reduce_broadcastable optimization
assert node.inputs[0].ndim == 2, node # now summation is over the original x's dimension 1.
assert op.axis == (0,), op.axis assert node.inputs[0].ndim == 2, node
assert op.axis == (0,), op.axis
def test_local_reduce_broadcast_some_1(self): def test_local_reduce_broadcast_some_1(self):
optimizer = optdb.query(self.mode._optimizer) for fct in [tensor.sum, tensor.all, tensor.any, tensor.prod,
tensor.max, tensor.min]:
x = T.TensorType('int64', (True, False, True))() x = T.TensorType('int64', (True, True, True))()
g = FunctionGraph([x], [x.sum(axis=[0, 2])]) f = theano.function([x], [fct(x, axis=[0, 2])], mode=self.mode)
optimizer.optimize(g) assert not any([
order = g.toposort() isinstance(node.op, T.CAReduce)
assert 0 == sum([isinstance(node.op, T.CAReduce) for node in order]) for node in f.maker.fgraph.toposort()])
class T_local_sum_dimshuffle(unittest.TestCase): class T_local_sum_dimshuffle(unittest.TestCase):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论