提交 e0978aa9 authored 作者: Frederic's avatar Frederic

rename opt local_sum_broadcast to local_reduce_broadcast.

It work for all CAReduce op.
上级 44d43bcf
...@@ -3151,7 +3151,7 @@ def local_cut_useless_reduce(node): ...@@ -3151,7 +3151,7 @@ def local_cut_useless_reduce(node):
#@register_canonicalize #@register_canonicalize
@register_specialize @register_specialize
@gof.local_optimizer([]) @gof.local_optimizer([])
def local_sum_broadcastable(node): def local_reduce_broadcastable(node):
"""Remove reduction over broadcastable dimensions""" """Remove reduction over broadcastable dimensions"""
if isinstance(node.op, T.CAReduce): if isinstance(node.op, T.CAReduce):
reduced, = node.inputs reduced, = node.inputs
......
...@@ -3313,7 +3313,7 @@ class T_local_sum(unittest.TestCase): ...@@ -3313,7 +3313,7 @@ class T_local_sum(unittest.TestCase):
finally: finally:
config.on_opt_error = backup config.on_opt_error = backup
def test_local_sum_broadcast_all_0(self): def test_local_reduce_broadcast_all_0(self):
optimizer = optdb.query(self.mode._optimizer) optimizer = optdb.query(self.mode._optimizer)
x = T.TensorType('int64', (True, True, True))() x = T.TensorType('int64', (True, True, True))()
...@@ -3323,7 +3323,7 @@ class T_local_sum(unittest.TestCase): ...@@ -3323,7 +3323,7 @@ class T_local_sum(unittest.TestCase):
isinstance(node.op, T.CAReduce) isinstance(node.op, T.CAReduce)
for node in g.toposort()]) for node in g.toposort()])
def test_local_sum_broadcast_all_1(self): def test_local_reduce_broadcast_all_1(self):
optimizer = optdb.query(self.mode._optimizer) optimizer = optdb.query(self.mode._optimizer)
x = T.TensorType('int64', (True, True))() x = T.TensorType('int64', (True, True))()
...@@ -3333,7 +3333,7 @@ class T_local_sum(unittest.TestCase): ...@@ -3333,7 +3333,7 @@ class T_local_sum(unittest.TestCase):
isinstance(node.op, T.CAReduce) isinstance(node.op, T.CAReduce)
for node in g.toposort()]) for node in g.toposort()])
def test_local_sum_broadcast_some_0(self): def test_local_reduce_broadcast_some_0(self):
optimizer = optdb.query(self.mode._optimizer) optimizer = optdb.query(self.mode._optimizer)
x = T.TensorType('int64', (True, False, True))() x = T.TensorType('int64', (True, False, True))()
...@@ -3348,12 +3348,12 @@ class T_local_sum(unittest.TestCase): ...@@ -3348,12 +3348,12 @@ class T_local_sum(unittest.TestCase):
op = node.op op = node.op
assert isinstance(op, T.CAReduce) assert isinstance(op, T.CAReduce)
# -- the leading broadcastable dimension has been dropped # -- the leading broadcastable dimension has been dropped
# by the local_sum_broadcastable optimization # by the local_reduce_broadcastable optimization
# now summation is over the original x's dimension 1. # now summation is over the original x's dimension 1.
assert node.inputs[0].ndim == 2, node assert node.inputs[0].ndim == 2, node
assert op.axis == (0,), op.axis assert op.axis == (0,), op.axis
def test_local_sum_broadcast_some_1(self): def test_local_reduce_broadcast_some_1(self):
optimizer = optdb.query(self.mode._optimizer) optimizer = optdb.query(self.mode._optimizer)
x = T.TensorType('int64', (True, False, True))() x = T.TensorType('int64', (True, False, True))()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论