Commit a5fe2dd0 authored by James Bergstra

local_sum_broadcastable - unit tests and more complete impl

Parent de93a9fe
...@@ -3140,16 +3140,34 @@ def local_cut_useless_reduce(node): ...@@ -3140,16 +3140,34 @@ def local_cut_useless_reduce(node):
@gof.local_optimizer([])
def local_sum_broadcastable(node):
    """Remove reduction over broadcastable dimensions.

    A CAReduce (e.g. Sum) over a broadcastable axis (length 1) is a no-op
    apart from the output dtype, so such axes can be dropped from the
    reduction.  If every reduced axis is broadcastable, the reduction is
    removed entirely; otherwise a smaller reduction over the remaining
    axes is built, which simplifies code generation, especially on GPU.
    """
    if isinstance(node.op, T.CAReduce):
        reduced, = node.inputs
        # Reductions may upcast (e.g. int8 sum -> int64), so when the
        # reduction is removed entirely we must still cast the result to
        # the original output dtype.
        odtype = node.outputs[0].dtype
        if node.op.axis is None:
            # axis=None reduces over all axes: removable only when every
            # axis is broadcastable, i.e. the input holds a single element.
            if all(reduced.broadcastable):
                # dimshuffle() with an empty pattern drops all broadcastable
                # dimensions, leaving a 0-d tensor.
                return [reduced.dimshuffle().astype(odtype)]
        else:
            axis = list(node.op.axis)
            # Reduction axes that are broadcastable and can be dropped.
            cuttable = [a for a in axis if reduced.broadcastable[a]]
            if cuttable:
                # -- we can remove some axes of summation,
                # which simplifies the codegen for sum, especially on GPU
                new_axis = []  # surviving reduction axes, renumbered
                pattern = []   # dimshuffle pattern keeping non-cut dims
                ii = 0         # position of dim p after the dimshuffle
                for p in range(reduced.ndim):
                    if p not in cuttable:
                        if p in axis:
                            new_axis.append(ii)
                        pattern.append(p)
                        ii += 1
                new_reduced = reduced.dimshuffle(*pattern)
                if new_axis:
                    # Rebuild the same kind of reduction over the axes
                    # that were not broadcastable.
                    new_op = node.op.__class__(axis=new_axis)
                    return [new_op(new_reduced)]
                else:
                    # -- in this case we can remove the reduction completely
                    return [new_reduced.astype(odtype)]
@register_specialize @register_specialize
@gof.local_optimizer([]) @gof.local_optimizer([])
......
...@@ -46,6 +46,7 @@ from theano.tensor import ( ...@@ -46,6 +46,7 @@ from theano.tensor import (
) )
from theano.tensor.elemwise import DimShuffle from theano.tensor.elemwise import DimShuffle
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from theano.compile.mode import optdb
mode_opt = theano.config.mode mode_opt = theano.config.mode
if mode_opt == 'FAST_COMPILE': if mode_opt == 'FAST_COMPILE':
...@@ -3288,6 +3289,51 @@ class T_local_sum(unittest.TestCase): ...@@ -3288,6 +3289,51 @@ class T_local_sum(unittest.TestCase):
finally: finally:
config.on_opt_error = backup config.on_opt_error = backup
def test_local_sum_broadcast_all_0(self):
optimizer = optdb.query(self.mode._optimizer)
x = T.TensorType('int64', (True, True, True))()
env = Env([x], [x.sum()])
optimizer.optimize(env)
assert not any([
isinstance(node.op, T.CAReduce)
for node in env.toposort()])
def test_local_sum_broadcast_all_1(self):
optimizer = optdb.query(self.mode._optimizer)
x = T.TensorType('int64', (True, True))()
env = Env([x], [x.sum(axis=[0, 1])])
optimizer.optimize(env)
assert not any([
isinstance(node.op, T.CAReduce)
for node in env.toposort()])
def test_local_sum_broadcast_some_0(self):
optimizer = optdb.query(self.mode._optimizer)
x = T.TensorType('int64', (True, False, True))()
env = Env([x], [x.sum(axis=[0, 1])])
optimizer.optimize(env)
order = env.toposort()
assert 1 == sum([isinstance(node.op, T.CAReduce) for node in order])
op = order[-2].op
assert isinstance(op, T.CAReduce)
# -- the leading broadcastable dimension has been dropped
# by the local_sum_broadcastable optimization
# now summation is over the original x's dimension 1.
assert order[-2].inputs[0].ndim == 2, order[-2]
assert op.axis == (0,), op.axis
def test_local_sum_broadcast_some_1(self):
optimizer = optdb.query(self.mode._optimizer)
x = T.TensorType('int64', (True, False, True))()
env = Env([x], [x.sum(axis=[0, 2])])
optimizer.optimize(env)
order = env.toposort()
assert 0 == sum([isinstance(node.op, T.CAReduce) for node in order])
class T_local_sum_dimshuffle(unittest.TestCase): class T_local_sum_dimshuffle(unittest.TestCase):
def setUp(self): def setUp(self):
......
Markdown is supported
0%
You are about to add 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment