提交 e0978aa9 authored 作者: Frederic's avatar Frederic

rename opt local_sum_broadcast to local_reduce_broadcast.

It work for all CAReduce op.
上级 44d43bcf
......@@ -3151,7 +3151,7 @@ def local_cut_useless_reduce(node):
#@register_canonicalize
@register_specialize
@gof.local_optimizer([])
def local_sum_broadcastable(node):
def local_reduce_broadcastable(node):
"""Remove reduction over broadcastable dimensions"""
if isinstance(node.op, T.CAReduce):
reduced, = node.inputs
......
......@@ -3313,7 +3313,7 @@ class T_local_sum(unittest.TestCase):
finally:
config.on_opt_error = backup
def test_local_sum_broadcast_all_0(self):
def test_local_reduce_broadcast_all_0(self):
optimizer = optdb.query(self.mode._optimizer)
x = T.TensorType('int64', (True, True, True))()
......@@ -3323,7 +3323,7 @@ class T_local_sum(unittest.TestCase):
isinstance(node.op, T.CAReduce)
for node in g.toposort()])
def test_local_sum_broadcast_all_1(self):
def test_local_reduce_broadcast_all_1(self):
optimizer = optdb.query(self.mode._optimizer)
x = T.TensorType('int64', (True, True))()
......@@ -3333,7 +3333,7 @@ class T_local_sum(unittest.TestCase):
isinstance(node.op, T.CAReduce)
for node in g.toposort()])
def test_local_sum_broadcast_some_0(self):
def test_local_reduce_broadcast_some_0(self):
optimizer = optdb.query(self.mode._optimizer)
x = T.TensorType('int64', (True, False, True))()
......@@ -3348,12 +3348,12 @@ class T_local_sum(unittest.TestCase):
op = node.op
assert isinstance(op, T.CAReduce)
# -- the leading broadcastable dimension has been dropped
# by the local_sum_broadcastable optimization
# by the local_reduce_broadcastable optimization
# now summation is over the original x's dimension 1.
assert node.inputs[0].ndim == 2, node
assert op.axis == (0,), op.axis
def test_local_sum_broadcast_some_1(self):
def test_local_reduce_broadcast_some_1(self):
optimizer = optdb.query(self.mode._optimizer)
x = T.TensorType('int64', (True, False, True))()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论