Unverified commit 84936418, authored by Kaustubh, committed via GitHub

Refactor and simplify CAReduce, Max, and Min Ops (#297)

Parent commit: 584c0c15
......@@ -3024,11 +3024,9 @@ class GpuCAReduceCPY(GpuKernelBase, HideC, CAReduceDtype):
"""
def __init__(self, scalar_op, axis=None, dtype=None, acc_dtype=None):
if not hasattr(scalar_op, "identity"):
if scalar_op.identity is None:
raise ValueError("No identity on scalar op")
CAReduceDtype.__init__(
self, scalar_op, axis=axis, dtype=dtype, acc_dtype=acc_dtype
)
super().__init__(scalar_op, axis=axis, dtype=dtype, acc_dtype=acc_dtype)
def __str__(self):
ax = ""
......@@ -3038,7 +3036,7 @@ class GpuCAReduceCPY(GpuKernelBase, HideC, CAReduceDtype):
def make_node(self, input):
ctx_name = infer_context_name(input)
res = CAReduceDtype.make_node(self, input)
res = super().make_node(input)
input = as_gpuarray_variable(input, ctx_name)
otype = GpuArrayType(
dtype=res.outputs[0].dtype,
......
......@@ -1259,19 +1259,19 @@ class UnaryScalarOp(ScalarOp):
class BinaryScalarOp(ScalarOp):
# One may define in subclasses the following fields:
# - `identity`: for an associative operation, identity corresponds to
# the neutral element. For instance, it will be 0 for addition, 1 for
# multiplication, True for "and", False for "or".
# - `commutative`: whether op(a, b) == op(b, a)
# - `associative`: whether op(op(a, b), c) == op(a, op(b, c))
commutative = None
associative = None
identity = None
"""
For an associative operation, the identity object corresponds to the neutral
element. For instance, it will be ``0`` for addition, ``1`` for multiplication,
``True`` for ``and``, ``False`` for ``or``.
"""
nin = 2
###############
# Comparisons
###############
class LogicalComparison(BinaryScalarOp):
def __init__(self, *args, **kwargs):
BinaryScalarOp.__init__(self, *args, **kwargs)
......@@ -1725,6 +1725,7 @@ class ScalarMaximum(BinaryScalarOp):
associative = True
nfunc_spec = ("maximum", 2, 1)
nfunc_variadic = "maximum"
identity = -np.inf
def impl(self, *inputs):
# The built-in max function don't support complex type
......@@ -1767,6 +1768,7 @@ class ScalarMinimum(BinaryScalarOp):
associative = True
nfunc_spec = ("minimum", 2, 1)
nfunc_variadic = "minimum"
identity = np.inf
def impl(self, *inputs):
# The built-in min function don't support complex type
......
This diff is collapsed (not shown).
......@@ -585,14 +585,45 @@ def max_and_argmax(a, axis=None, keepdims=False):
return [out, argout]
class Max(CAReduce):
class NonZeroCAReduce(CAReduce):
    """A `CAReduce` whose generated C code rejects zero-sized reduction axes.

    Reductions such as max/min have no identity element, so reducing over
    an empty axis is undefined; this subclass appends a runtime check to
    the generated C code that raises ``ValueError`` in that case.
    """

    def _c_all(self, node, name, inames, onames, sub):
        # Start from the standard CAReduce code fragments and extend the
        # declaration/allocation sections with the zero-size check.
        decl, checks, alloc, loop, end = super()._c_all(node, name, inames, onames, sub)

        # We add an additional check for zero-sized dimensions (This seems like
        # something that could enabled in `elemwise_cgen.make_checks`.)
        iname = inames[0]

        axis = self.axis
        if axis is None:
            # `axis is None` means reduce over every dimension of the input.
            axis = list(range(len(node.inputs[0].type.broadcastable)))

        # Build a 0/1 mask over the input's dimensions: 1 marks a reduced axis.
        pattern = [0] * len(node.inputs[0].broadcastable)
        for i in axis:
            pattern[i] = 1

        # Strip the surrounding brackets from e.g. "[1, 0, 1]" to get a
        # C array initializer list ("1, 0, 1").
        pattern_ = str(pattern)[1:-1]

        decl += f"""int tosum[]={{{pattern_}}};"""
        alloc += f"""
for(int i=0;i<PyArray_NDIM({iname});i++){{
if(PyArray_DIMS({iname})[i]==0 && tosum[i]){{
PyErr_Format(PyExc_ValueError,
"Input of CAReduce{{{node.op.scalar_op}}} has zero-size on axis %%d",i);
{sub["fail"]};
}}
}}
"""
        # NOTE(review): `%%d` reaches the generated C source as a literal
        # `%%`, which `PyErr_Format` renders as the text "%d" (dropping `i`)
        # unless this template goes through a further `%`-substitution pass
        # in the linker — confirm against the C-code generation pipeline.
        return decl, checks, alloc, loop, end
class Max(NonZeroCAReduce):
    """Maximum reduction over the given axes.

    Inherits the zero-size-axis guard from `NonZeroCAReduce`, since the
    maximum has no identity element.
    """

    nfunc_spec = ("max", 1, 1)

    def __init__(self, axis):
        # Delegate to the CAReduce machinery with the scalar maximum op.
        super().__init__(aes.scalar_maximum, axis=axis)
class Min(CAReduce):
class Min(NonZeroCAReduce):
nfunc_spec = ("min", 1, 1)
def __init__(self, axis):
......
......@@ -60,6 +60,7 @@ from aesara.tensor.math import (
All,
Any,
Dot,
NonZeroCAReduce,
Prod,
ProdWithoutZeros,
Sum,
......@@ -1534,14 +1535,18 @@ def local_op_of_op(fgraph, node):
return [combined(node_inps.owner.inputs[0])]
ALL_REDUCE = [
CAReduce,
All,
Any,
Sum,
Prod,
ProdWithoutZeros,
] + CAReduce.__subclasses__()
# Every reduction Op the rewrites below may encounter: the known concrete
# classes plus any subclasses discovered at runtime.
ALL_REDUCE = [
    CAReduce,
    All,
    Any,
    Sum,
    Prod,
    ProdWithoutZeros,
    *CAReduce.__subclasses__(),
    *NonZeroCAReduce.__subclasses__(),
]
@register_canonicalize
......
......@@ -372,7 +372,7 @@ class TestCAReduce(unittest_tools.InferShapeTester):
zv = xv
if pre_scalar_op is not None:
zv = Elemwise(scalar_op=pre_scalar_op)(x).eval({x: xv})
numpy_raised = False
if len(tosum) > 1 and any([a < 0 for a in tosum]):
# In that case, we need to use the good order of axis
# in the reduction.
......@@ -404,17 +404,19 @@ class TestCAReduce(unittest_tools.InferShapeTester):
for axis in reversed(sorted(tosum)):
zv = np.multiply.reduce(zv, axis)
elif scalar_op == aes.scalar_maximum:
try:
for axis in reversed(sorted(tosum)):
zv = np.maximum.reduce(zv, axis)
except ValueError:
numpy_raised = True
# There is no identity value for the maximum function
# So we can't support shape of dimensions 0.
if np.prod(zv.shape) == 0:
continue
for axis in reversed(sorted(tosum)):
zv = np.maximum.reduce(zv, axis)
elif scalar_op == aes.scalar_minimum:
try:
for axis in reversed(sorted(tosum)):
zv = np.minimum.reduce(zv, axis)
except ValueError:
numpy_raised = True
# There is no identity value for the minimum function
# So we can't support shape of dimensions 0.
if np.prod(zv.shape) == 0:
continue
for axis in reversed(sorted(tosum)):
zv = np.minimum.reduce(zv, axis)
elif scalar_op == aes.or_:
for axis in reversed(sorted(tosum)):
zv = np.bitwise_or.reduce(zv, axis)
......@@ -432,24 +434,21 @@ class TestCAReduce(unittest_tools.InferShapeTester):
raise Exception(
f"Test for CAReduce with scalar_op {scalar_op} not implemented"
)
if scalar_op in [aes.scalar_maximum, aes.scalar_minimum] and numpy_raised:
with pytest.raises(ValueError):
f(xv)
if test_nan:
try:
assert self.type.values_eq(f(xv), zv), (f(xv), zv)
except NotImplementedError:
# GpuCAReduce don't implement all cases when size is 0
assert xv.size == 0
else:
if test_nan:
try:
assert self.type.values_eq(f(xv), zv), (f(xv), zv)
except NotImplementedError:
# GpuCAReduce don't implement all cases when size is 0
assert xv.size == 0
else:
try:
f_xv = f(xv)
assert f_xv.shape == zv.shape, (f_xv, zv)
utt.assert_allclose(zv, f_xv)
except NotImplementedError:
# GpuCAReduce don't implement all cases when size is 0
assert xv.size == 0
try:
f_xv = f(xv)
assert f_xv.shape == zv.shape, (f_xv, zv)
utt.assert_allclose(zv, f_xv)
except NotImplementedError:
# GpuCAReduce don't implement all cases when size is 0
assert xv.size == 0
x = self.type(dtype, [(entry == 1) for entry in xsh])("x")
if tensor_op is None:
......
Markdown formatting supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment