提交 a5fe2dd0，作者：James Bergstra

local_sum_broadcastable - unit tests and more complete impl

上级 de93a9fe
......@@ -3140,16 +3140,34 @@ def local_cut_useless_reduce(node):
@gof.local_optimizer([])
def local_sum_broadcastable(node):
    """Remove reduction over broadcastable dimensions.

    A broadcastable dimension always has length 1, so reducing over it
    changes nothing but the rank of the result.  This optimization strips
    such axes from a CAReduce (e.g. Sum), which simplifies the generated
    code for the reduction, especially on GPU.

    Returns a one-element list with the replacement variable, or None
    (implicitly) when no rewrite applies.
    """
    if isinstance(node.op, T.CAReduce):
        reduced, = node.inputs
        # The reduction may upcast (e.g. int8 sum -> int64); keep the
        # original output dtype whenever we bypass the reduce entirely.
        odtype = node.outputs[0].dtype
        if node.op.axis is None:
            # Reduction over all axes: removable only when every dimension
            # is broadcastable, i.e. the input holds a single element.
            if all(reduced.broadcastable):
                return [reduced.dimshuffle().astype(odtype)]
        else:
            axis = list(node.op.axis)
            cuttable = [a for a in axis if reduced.broadcastable[a]]
            if cuttable:
                # -- we can remove some axes of summation,
                #    which simplifies the codegen for sum, especially on GPU
                new_axis = []  # surviving reduction axes, renumbered
                pattern = []   # dimshuffle pattern keeping non-cuttable dims
                ii = 0
                for p in range(reduced.ndim):
                    if p not in cuttable:
                        if p in axis:
                            new_axis.append(ii)
                        pattern.append(p)
                        ii += 1
                new_reduced = reduced.dimshuffle(*pattern)
                if new_axis:
                    new_op = node.op.__class__(axis=new_axis)
                    return [new_op(new_reduced)]
                else:
                    # -- in this case we can remove the reduction completely
                    return [new_reduced.astype(odtype)]
@register_specialize
@gof.local_optimizer([])
......
......@@ -46,6 +46,7 @@ from theano.tensor import (
)
from theano.tensor.elemwise import DimShuffle
from theano.tests import unittest_tools as utt
from theano.compile.mode import optdb
mode_opt = theano.config.mode
if mode_opt == 'FAST_COMPILE':
......@@ -3288,6 +3289,51 @@ class T_local_sum(unittest.TestCase):
finally:
config.on_opt_error = backup
def test_local_sum_broadcast_all_0(self):
    """An implicit-axis sum over an all-broadcastable tensor is removed."""
    opt = optdb.query(self.mode._optimizer)
    x = T.TensorType('int64', (True, True, True))()
    graph = Env([x], [x.sum()])
    opt.optimize(graph)
    # No CAReduce node may survive the optimization.
    for node in graph.toposort():
        assert not isinstance(node.op, T.CAReduce)
def test_local_sum_broadcast_all_1(self):
    """An explicit-axis sum over all-broadcastable dims is removed."""
    opt = optdb.query(self.mode._optimizer)
    x = T.TensorType('int64', (True, True))()
    graph = Env([x], [x.sum(axis=[0, 1])])
    opt.optimize(graph)
    reduces = [n for n in graph.toposort()
               if isinstance(n.op, T.CAReduce)]
    assert not reduces
def test_local_sum_broadcast_some_0(self):
    """Broadcastable axes are stripped from a partial reduction."""
    opt = optdb.query(self.mode._optimizer)
    x = T.TensorType('int64', (True, False, True))()
    graph = Env([x], [x.sum(axis=[0, 1])])
    opt.optimize(graph)
    nodes = graph.toposort()
    n_reduce = len([n for n in nodes if isinstance(n.op, T.CAReduce)])
    assert n_reduce == 1
    reduce_node = nodes[-2]
    assert isinstance(reduce_node.op, T.CAReduce)
    # -- the leading broadcastable dimension has been dropped
    # by the local_sum_broadcastable optimization
    # now summation is over the original x's dimension 1.
    assert reduce_node.inputs[0].ndim == 2, reduce_node
    assert reduce_node.op.axis == (0,), reduce_node.op.axis
def test_local_sum_broadcast_some_1(self):
    """When every summed axis is broadcastable, the reduce disappears."""
    opt = optdb.query(self.mode._optimizer)
    x = T.TensorType('int64', (True, False, True))()
    graph = Env([x], [x.sum(axis=[0, 2])])
    opt.optimize(graph)
    leftover = [n for n in graph.toposort()
                if isinstance(n.op, T.CAReduce)]
    assert len(leftover) == 0
class T_local_sum_dimshuffle(unittest.TestCase):
def setUp(self):
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册或登录后发表评论