提交 a8e715d9 authored 作者: Chinnadhurai Sankar's avatar Chinnadhurai Sankar 提交者: Pascal Lamblin

custom bitwise_and.reduce function using identity value -1

上级 9d887ad2
...@@ -330,6 +330,22 @@ class test_Broadcast(unittest.TestCase): ...@@ -330,6 +330,22 @@ class test_Broadcast(unittest.TestCase):
assert (f(xv) == zv).all() assert (f(xv) == zv).all()
def reduce_bitwise_and(x, axis=-1, dtype='int8'):
identity = numpy.array((-1,), dtype=dtype)[0]
if 0 in x.shape and x.shape[axis] != 0:
new_shape = tuple([s for i, s in enumerate(x.shape) if i != axis])
return numpy.empty(shape=new_shape, dtype=x.dtype)
def custom_reduce(a):
out = identity
for i in range(a.size):
out = numpy.bitwise_and(a[i], out)
return out
return numpy.apply_along_axis(custom_reduce, axis, x)
class test_CAReduce(unittest_tools.InferShapeTester): class test_CAReduce(unittest_tools.InferShapeTester):
op = CAReduce op = CAReduce
cases = [((5, 6), None), cases = [((5, 6), None),
...@@ -430,7 +446,7 @@ class test_CAReduce(unittest_tools.InferShapeTester): ...@@ -430,7 +446,7 @@ class test_CAReduce(unittest_tools.InferShapeTester):
zv = numpy.bitwise_or.reduce(zv, axis) zv = numpy.bitwise_or.reduce(zv, axis)
elif scalar_op == scalar.and_: elif scalar_op == scalar.and_:
for axis in reversed(sorted(tosum)): for axis in reversed(sorted(tosum)):
zv = numpy.bitwise_and.reduce(zv, axis) zv = reduce_bitwise_and(zv, axis, dtype=dtype)
elif scalar_op == scalar.xor: elif scalar_op == scalar.xor:
# There is no identity value for the xor function # There is no identity value for the xor function
# So we can't support shape of dimensions 0. # So we can't support shape of dimensions 0.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论