Commit 584496dc authored by Brandon T. Willard, committed by Brandon T. Willard

Remove useless conjugate Ops from graphs

Parent ecd6b49c
......@@ -26,8 +26,9 @@ from aesara.sparse.type import SparseTensorType, _is_sparse
from aesara.sparse.utils import hash_from_sparse
from aesara.tensor import basic as at
from aesara.tensor.basic import Split
from aesara.tensor.math import _conj
from aesara.tensor.math import add as at_add
from aesara.tensor.math import arcsin, arcsinh, arctan, arctanh, ceil, conj, deg2rad
from aesara.tensor.math import arcsin, arcsinh, arctan, arctanh, ceil, deg2rad
from aesara.tensor.math import dot as at_dot
from aesara.tensor.math import exp, expm1, floor, log, log1p, maximum, minimum
from aesara.tensor.math import pow as at_pow
......@@ -322,7 +323,6 @@ def override_dense(*methods):
"max",
"argmin",
"argmax",
"conj",
"round",
"trace",
"cumsum",
......@@ -451,6 +451,9 @@ class _sparse_py_operators(_tensor_py_operators):
ret = get_item_2d(self, args)
return ret
def conj(self):
    """Return the complex conjugate of this sparse variable.

    Delegates to the module-level ``conjugate``, which — per this
    commit — returns the input unchanged for non-complex dtypes
    instead of inserting a useless conjugate `Op`.
    """
    return conjugate(self)
class SparseVariable(_sparse_py_operators, TensorVariable):
format = property(lambda self: self.type.format)
......@@ -3548,8 +3551,8 @@ def sqrt(x):
# see decorator for function body
@structured_monoid(conj) # type: ignore[no-redef]
def conj(x):
@structured_monoid(_conj) # type: ignore[no-redef]
def _conj(x):
"""
Elemwise complex conjugate of `x`.
......@@ -3557,6 +3560,16 @@ def conj(x):
# see decorator for function body
def conjugate(x):
    """Return the complex conjugate of the sparse variable `x`.

    Real-valued inputs are their own conjugates, so they are returned
    as-is and no conjugate `Op` is added to the graph.
    """
    sparse_x = as_sparse_variable(x)
    # Only complex dtypes need an actual conjugate node.
    if sparse_x.type.dtype in complex_dtypes:
        return _conj(sparse_x)
    return sparse_x


conj = conjugate
class TrueDot(Op):
# TODO
......
......@@ -2,6 +2,7 @@ import numpy as np
import scipy.sparse
import aesara
from aesara import scalar as aes
from aesara.graph.type import HasDataType
from aesara.tensor.type import TensorType
......@@ -106,22 +107,28 @@ class SparseTensorType(TensorType, HasDataType):
and value.dtype == self.dtype
):
return value
if strict:
raise TypeError(
f"{value} is not sparse, or not the right dtype (is {value.dtype}, "
f"expected {self.dtype})"
)
# The input format could be converted here
if allow_downcast:
sp = self.format_cls[self.format](value, dtype=self.dtype)
else:
sp = self.format_cls[self.format](value)
if str(sp.dtype) != self.dtype:
data = self.format_cls[self.format](value)
up_dtype = aes.upcast(self.dtype, data.dtype)
if up_dtype != self.dtype:
raise NotImplementedError(
f"Expected {self.dtype} dtype but got {sp.dtype}"
f"Expected {self.dtype} dtype but got {data.dtype}"
)
sp = data.astype(up_dtype)
if sp.format != self.format:
raise NotImplementedError()
return sp
@classmethod
......
......@@ -1465,11 +1465,21 @@ def complex(real, imag):
"""Return complex-valued tensor with `real` and `imag` components"""
@scalar_elemwise
def conj(z):
@scalar_elemwise(symbolname="conj")
def _conj(z):
"""Return the complex conjugate of `z`."""
def conjugate(x):
    """Return the complex conjugate of `x`.

    For non-complex dtypes the (converted) input is returned directly,
    since conjugation is the identity on real data — this keeps useless
    conjugate `Op`s out of the graph.
    """
    x_var = as_tensor_variable(x)
    # Insert a conjugate node only when the dtype is actually complex.
    if x_var.type.dtype in complex_dtypes:
        return _conj(x_var)
    return x_var


conj = conjugate
@scalar_elemwise
def complex_from_polar(abs, angle):
"""Return complex-valued tensor from polar coordinate specification."""
......@@ -2931,6 +2941,7 @@ __all__ = [
"angle",
"complex",
"conj",
"conjugate",
"complex_from_polar",
"sum",
"prod",
......
......@@ -58,6 +58,7 @@ from aesara.tensor.math import (
Prod,
ProdWithoutZeros,
Sum,
_conj,
)
from aesara.tensor.math import abs as at_abs
from aesara.tensor.math import (
......@@ -86,6 +87,7 @@ from aesara.tensor.math import true_div
from aesara.tensor.shape import Shape, Shape_i
from aesara.tensor.subtensor import Subtensor
from aesara.tensor.type import (
complex_dtypes,
uint_dtypes,
values_eq_approx_remove_inf,
values_eq_approx_remove_inf_nan,
......@@ -3552,3 +3554,13 @@ local_sigmoid_logit = PatternSub(
)
register_canonicalize(local_sigmoid_logit)
register_specialize(local_sigmoid_logit)
@register_canonicalize
@register_useless
@local_optimizer([_conj])
def local_useless_conj(fgraph, node):
    r"""Remove `conj` `Op`\s applied to non-imaginary variable types."""
    (inp,) = node.inputs
    # Conjugation is the identity on real-valued data, so the `Op` can
    # be replaced by its own input; complex inputs are left alone.
    if inp.type.dtype in complex_dtypes:
        return None
    return [inp]
......@@ -3170,6 +3170,15 @@ SqrtTester = elemwise_checker(sparse.sqrt, np.sqrt, gap=(0, 10))
ConjTester = elemwise_checker(sparse.conj, np.conj, grad_test=False)
def test_useless_conj():
    """`conj` on a sparse variable builds a node only for complex dtypes."""
    # Complex data: conjugation must produce a new variable.
    z = sparse.SparseTensorType("csr", dtype="complex128")()
    assert z.conj() is not z

    # Real data: conjugation is a no-op that returns the input itself.
    r = sparse.SparseTensorType("csr", dtype="float64")()
    assert r.conj() is r
class TestMulSV:
def test_mul_s_v_grad(self):
sp_types = {"csc": sp.sparse.csc_matrix, "csr": sp.sparse.csr_matrix}
......
......@@ -13,69 +13,72 @@ from aesara.tensor.type import DenseTensorType
class TestSparseVariable:
@pytest.mark.parametrize(
"method, exp_type, cm",
"method, exp_type, cm, x",
[
("__abs__", DenseTensorType, None),
("__neg__", SparseTensorType, ExitStack()),
("__ceil__", DenseTensorType, None),
("__floor__", DenseTensorType, None),
("__trunc__", DenseTensorType, None),
("transpose", DenseTensorType, None),
("any", DenseTensorType, None),
("all", DenseTensorType, None),
("flatten", DenseTensorType, None),
("ravel", DenseTensorType, None),
("arccos", DenseTensorType, None),
("arcsin", DenseTensorType, None),
("arctan", DenseTensorType, None),
("arccosh", DenseTensorType, None),
("arcsinh", DenseTensorType, None),
("arctanh", DenseTensorType, None),
("ceil", DenseTensorType, None),
("cos", DenseTensorType, None),
("cosh", DenseTensorType, None),
("deg2rad", DenseTensorType, None),
("exp", DenseTensorType, None),
("exp2", DenseTensorType, None),
("expm1", DenseTensorType, None),
("floor", DenseTensorType, None),
("log", DenseTensorType, None),
("log10", DenseTensorType, None),
("log1p", DenseTensorType, None),
("log2", DenseTensorType, None),
("rad2deg", DenseTensorType, None),
("sin", DenseTensorType, None),
("sinh", DenseTensorType, None),
("sqrt", DenseTensorType, None),
("tan", DenseTensorType, None),
("tanh", DenseTensorType, None),
("copy", DenseTensorType, None),
("sum", DenseTensorType, ExitStack()),
("prod", DenseTensorType, None),
("mean", DenseTensorType, None),
("var", DenseTensorType, None),
("std", DenseTensorType, None),
("min", DenseTensorType, None),
("max", DenseTensorType, None),
("argmin", DenseTensorType, None),
("argmax", DenseTensorType, None),
("nonzero", DenseTensorType, ExitStack()),
("nonzero_values", DenseTensorType, None),
("argsort", DenseTensorType, ExitStack()),
("conj", DenseTensorType, None),
("round", DenseTensorType, None),
("trace", DenseTensorType, None),
("zeros_like", SparseTensorType, ExitStack()),
("ones_like", DenseTensorType, ExitStack()),
("cumsum", DenseTensorType, None),
("cumprod", DenseTensorType, None),
("ptp", DenseTensorType, None),
("squeeze", DenseTensorType, None),
("diagonal", DenseTensorType, None),
("__abs__", DenseTensorType, None, None),
("__neg__", SparseTensorType, ExitStack(), None),
("__ceil__", DenseTensorType, None, None),
("__floor__", DenseTensorType, None, None),
("__trunc__", DenseTensorType, None, None),
("transpose", DenseTensorType, None, None),
("any", DenseTensorType, None, None),
("all", DenseTensorType, None, None),
("flatten", DenseTensorType, None, None),
("ravel", DenseTensorType, None, None),
("arccos", DenseTensorType, None, None),
("arcsin", DenseTensorType, None, None),
("arctan", DenseTensorType, None, None),
("arccosh", DenseTensorType, None, None),
("arcsinh", DenseTensorType, None, None),
("arctanh", DenseTensorType, None, None),
("ceil", DenseTensorType, None, None),
("cos", DenseTensorType, None, None),
("cosh", DenseTensorType, None, None),
("deg2rad", DenseTensorType, None, None),
("exp", DenseTensorType, None, None),
("exp2", DenseTensorType, None, None),
("expm1", DenseTensorType, None, None),
("floor", DenseTensorType, None, None),
("log", DenseTensorType, None, None),
("log10", DenseTensorType, None, None),
("log1p", DenseTensorType, None, None),
("log2", DenseTensorType, None, None),
("rad2deg", DenseTensorType, None, None),
("sin", DenseTensorType, None, None),
("sinh", DenseTensorType, None, None),
("sqrt", DenseTensorType, None, None),
("tan", DenseTensorType, None, None),
("tanh", DenseTensorType, None, None),
("copy", DenseTensorType, None, None),
("sum", DenseTensorType, ExitStack(), None),
("prod", DenseTensorType, None, None),
("mean", DenseTensorType, None, None),
("var", DenseTensorType, None, None),
("std", DenseTensorType, None, None),
("min", DenseTensorType, None, None),
("max", DenseTensorType, None, None),
("argmin", DenseTensorType, None, None),
("argmax", DenseTensorType, None, None),
("nonzero", DenseTensorType, ExitStack(), None),
("nonzero_values", DenseTensorType, None, None),
("argsort", DenseTensorType, ExitStack(), None),
("conj", SparseTensorType, ExitStack(), at.cmatrix("x")),
("round", DenseTensorType, None, None),
("trace", DenseTensorType, None, None),
("zeros_like", SparseTensorType, ExitStack(), None),
("ones_like", DenseTensorType, ExitStack(), None),
("cumsum", DenseTensorType, None, None),
("cumprod", DenseTensorType, None, None),
("ptp", DenseTensorType, None, None),
("squeeze", DenseTensorType, None, None),
("diagonal", DenseTensorType, None, None),
],
)
def test_unary(self, method, exp_type, cm):
def test_unary(self, method, exp_type, cm, x):
if x is None:
x = at.dmatrix("x")
x = sparse.csr_from_dense(x)
method_to_call = getattr(x, method)
......@@ -98,7 +101,7 @@ class TestSparseVariable:
assert all(isinstance(out.type, exp_type) for out in z_outs)
f = aesara.function([x], z, on_unused_input="ignore")
f = aesara.function([x], z, on_unused_input="ignore", allow_input_downcast=True)
res = f([[1.1, 0.0, 2.0], [-1.0, 0.0, 0.0]])
......
......@@ -696,7 +696,7 @@ TestComplexFromPolarBroadcast = makeBroadcastTester(
)
TestConjBroadcast = makeBroadcastTester(
op=conj, expected=np.conj, good=_good_broadcast_unary_normal
op=conj, expected=np.conj, good={"complex": _good_broadcast_unary_normal["complex"]}
)
......@@ -2567,6 +2567,10 @@ class TestTensorInstanceMethods:
assert_array_equal(Z.conj().eval({Z: z}), z.conj())
assert_array_equal(Z.conjugate().eval({Z: z}), z.conj())
# No conjugate when the data type isn't complex
assert X.type.dtype not in complex_dtypes
assert X.conj() is X
def test_round(self):
X, _ = self.vars
x, _ = self.vals
......
......@@ -34,7 +34,7 @@ from aesara.tensor.basic_opt import local_dimshuffle_lift
from aesara.tensor.blas import Dot22, Gemv
from aesara.tensor.blas_c import CGemv
from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from aesara.tensor.math import Dot, MaxAndArgmax, Prod, Sum
from aesara.tensor.math import Dot, MaxAndArgmax, Prod, Sum, _conj
from aesara.tensor.math import abs as at_abs
from aesara.tensor.math import add
from aesara.tensor.math import all as at_all
......@@ -119,6 +119,7 @@ from aesara.tensor.type import (
values_eq_approx_remove_nan,
vector,
vectors,
zscalar,
)
from aesara.tensor.var import TensorConstant
from tests import unittest_tools as utt
......@@ -4619,3 +4620,20 @@ def test_local_logit_sigmoid():
fg = optimize(FunctionGraph([x], [out]))
assert not list(fg.toposort())
assert fg.inputs[0] is fg.outputs[0]
def test_local_useless_conj():
    """`local_useless_conj` removes `_conj` nodes exactly when the input is real.

    The same optimization mode is used for both cases, so build it once
    instead of twice (the original duplicated the ``including`` call and
    carried a copy-pasted "Test for all zeros" comment from an unrelated
    test).
    """
    mode_with_opt = get_default_mode().including(
        "canonicalization", "local_useless_conj"
    )

    # Real input: the `_conj` `Op` is useless and must be optimized away.
    x = scalar()
    s = _conj(x)
    f = function([x], s, mode=mode_with_opt)
    assert not any(node.op == _conj for node in f.maker.fgraph.apply_nodes)

    # Complex input: the `_conj` `Op` is meaningful and must be kept.
    x = zscalar()
    s = _conj(x)
    f = function([x], s, mode=mode_with_opt)
    assert any(node.op == _conj for node in f.maker.fgraph.apply_nodes)
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Finish editing this comment first!
Register or sign in to post a comment