提交 f85d7e34 authored 作者: Frederic's avatar Frederic

rename opt local_sum_broadcast to local_reduce_broadcast, update test and fix crash.

This new opt was crashing with the tensor.min() / tensor.max() CAReduce. I fixed this, and I now test the opt with many more reductions than just sum.
上级 e0978aa9
......@@ -3176,7 +3176,12 @@ def local_reduce_broadcastable(node):
ii += 1
new_reduced = reduced.dimshuffle(*pattern)
if new_axis:
new_op = node.op.__class__(axis=new_axis)
if type(node.op) == theano.tensor.elemwise.CAReduce:
# This happen for tensor.max(), tensor.min()
new_op = node.op.__class__(node.op.scalar_op,
axis=new_axis)
else:
new_op = node.op.__class__(axis=new_axis)
return [new_op(new_reduced)]
else:
# -- in this case we can remove the reduction completely
......
......@@ -3313,54 +3313,59 @@ class T_local_sum(unittest.TestCase):
finally:
config.on_opt_error = backup
def test_local_reduce_broadcast_all_0(self):
    """Summing a fully-broadcastable input: the CAReduce node
    should be optimized away entirely."""
    opt = optdb.query(self.mode._optimizer)
    inp = T.TensorType('int64', (True, True, True))()
    fgraph = FunctionGraph([inp], [inp.sum()])
    opt.optimize(fgraph)
    # No reduction op may survive the optimization.
    leftover = [node for node in fgraph.toposort()
                if isinstance(node.op, T.CAReduce)]
    assert not leftover
class T_local_reduce(unittest.TestCase):
def setUp(self):
    """Build a compile mode that includes the optimization phases
    under test (canonicalize + specialize)."""
    base_mode = theano.compile.get_default_mode()
    self.mode = base_mode.including('canonicalize',
                                    'specialize')
def test_local_reduce_broadcast_all_1(self):
optimizer = optdb.query(self.mode._optimizer)
def test_local_reduce_broadcast_all_0(self):
    """Reducing over a fully-broadcastable input must remove the
    CAReduce node, for every supported reduction."""
    reductions = [tensor.sum, tensor.all, tensor.any, tensor.prod,
                  tensor.max, tensor.min]
    for reduce_fct in reductions:
        inp = T.TensorType('int64', (True, True, True))()
        fn = theano.function([inp], [reduce_fct(inp)], mode=self.mode)
        topo = fn.maker.fgraph.toposort()
        # The optimization must have eliminated the reduction.
        assert all(not isinstance(node.op, T.CAReduce)
                   for node in topo)
x = T.TensorType('int64', (True, True))()
g = FunctionGraph([x], [x.sum(axis=[0, 1])])
optimizer.optimize(g)
assert not any([
isinstance(node.op, T.CAReduce)
for node in g.toposort()])
def test_local_reduce_broadcast_all_1(self):
    """Reducing over explicit axes that are all broadcastable must
    also remove the CAReduce node, for every supported reduction."""
    reductions = [tensor.sum, tensor.all, tensor.any, tensor.prod,
                  tensor.max, tensor.min]
    for reduce_fct in reductions:
        inp = T.TensorType('int64', (True, True))()
        fn = theano.function([inp], [reduce_fct(inp, axis=[0, 1])],
                             mode=self.mode)
        topo = fn.maker.fgraph.toposort()
        # The optimization must have eliminated the reduction.
        assert all(not isinstance(node.op, T.CAReduce)
                   for node in topo)
def test_local_reduce_broadcast_some_0(self):
    """Only some reduced dims are broadcastable: the reduction must
    survive, but over the squeezed input and fewer axes.

    NOTE(review): this span contained both the pre-refactor body
    (FunctionGraph/optimizer based, sum only) and the post-refactor
    loop-based body concatenated by a markerless diff; only the
    post-refactor body is kept here.
    """
    for fct in [tensor.sum, tensor.all, tensor.any, tensor.prod,
                tensor.max, tensor.min]:
        x = T.TensorType('int64', (True, False, True))()
        f = theano.function([x], [fct(x, axis=[0, 1])], mode=self.mode)
        order = f.maker.fgraph.toposort()
        # Exactly one reduction node must remain in the graph.
        assert 1 == sum([isinstance(node.op, T.CAReduce)
                         for node in order])
        node = [node for node in order if isinstance(node.op,
                                                     tensor.CAReduce)][0]
        op = node.op
        assert isinstance(op, T.CAReduce)
        # -- the leading broadcastable dimension has been dropped
        #    by the local_reduce_broadcastable optimization;
        #    now the reduction is over the original x's dimension 1.
        assert node.inputs[0].ndim == 2, node
        assert op.axis == (0,), op.axis
def test_local_reduce_broadcast_some_1(self):
    """Reducing broadcastable axes [0, 2] of a fully-broadcastable
    input must remove the CAReduce node for every reduction.

    NOTE(review): this span contained both the pre-refactor body
    (FunctionGraph/optimizer based, sum only) and the post-refactor
    loop-based body concatenated by a markerless diff; only the
    post-refactor body is kept here.
    """
    for fct in [tensor.sum, tensor.all, tensor.any, tensor.prod,
                tensor.max, tensor.min]:
        x = T.TensorType('int64', (True, True, True))()
        f = theano.function([x], [fct(x, axis=[0, 2])], mode=self.mode)
        # The optimization must have eliminated the reduction.
        assert not any([
            isinstance(node.op, T.CAReduce)
            for node in f.maker.fgraph.toposort()])
class T_local_sum_dimshuffle(unittest.TestCase):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论