提交 584496dc authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Remove useless conjugate Ops from graphs

上级 ecd6b49c
...@@ -26,8 +26,9 @@ from aesara.sparse.type import SparseTensorType, _is_sparse ...@@ -26,8 +26,9 @@ from aesara.sparse.type import SparseTensorType, _is_sparse
from aesara.sparse.utils import hash_from_sparse from aesara.sparse.utils import hash_from_sparse
from aesara.tensor import basic as at from aesara.tensor import basic as at
from aesara.tensor.basic import Split from aesara.tensor.basic import Split
from aesara.tensor.math import _conj
from aesara.tensor.math import add as at_add from aesara.tensor.math import add as at_add
from aesara.tensor.math import arcsin, arcsinh, arctan, arctanh, ceil, conj, deg2rad from aesara.tensor.math import arcsin, arcsinh, arctan, arctanh, ceil, deg2rad
from aesara.tensor.math import dot as at_dot from aesara.tensor.math import dot as at_dot
from aesara.tensor.math import exp, expm1, floor, log, log1p, maximum, minimum from aesara.tensor.math import exp, expm1, floor, log, log1p, maximum, minimum
from aesara.tensor.math import pow as at_pow from aesara.tensor.math import pow as at_pow
...@@ -322,7 +323,6 @@ def override_dense(*methods): ...@@ -322,7 +323,6 @@ def override_dense(*methods):
"max", "max",
"argmin", "argmin",
"argmax", "argmax",
"conj",
"round", "round",
"trace", "trace",
"cumsum", "cumsum",
...@@ -451,6 +451,9 @@ class _sparse_py_operators(_tensor_py_operators): ...@@ -451,6 +451,9 @@ class _sparse_py_operators(_tensor_py_operators):
ret = get_item_2d(self, args) ret = get_item_2d(self, args)
return ret return ret
def conj(self):
    """Return the complex conjugate of this sparse variable.

    Delegates to :func:`conjugate`, which is a no-op for
    non-complex dtypes.
    """
    return conjugate(self)
class SparseVariable(_sparse_py_operators, TensorVariable): class SparseVariable(_sparse_py_operators, TensorVariable):
format = property(lambda self: self.type.format) format = property(lambda self: self.type.format)
...@@ -3548,8 +3551,8 @@ def sqrt(x): ...@@ -3548,8 +3551,8 @@ def sqrt(x):
# see decorator for function body # see decorator for function body
@structured_monoid(conj) # type: ignore[no-redef] @structured_monoid(_conj) # type: ignore[no-redef]
def conj(x): def _conj(x):
""" """
Elemwise complex conjugate of `x`. Elemwise complex conjugate of `x`.
...@@ -3557,6 +3560,16 @@ def conj(x): ...@@ -3557,6 +3560,16 @@ def conj(x):
# see decorator for function body # see decorator for function body
def conjugate(x):
    """Return the element-wise complex conjugate of sparse `x`.

    Non-complex inputs are returned unchanged, so no useless
    conjugate `Op` is ever added to the graph.
    """
    sparse_x = as_sparse_variable(x)
    if sparse_x.type.dtype in complex_dtypes:
        return _conj(sparse_x)
    return sparse_x


conj = conjugate
class TrueDot(Op): class TrueDot(Op):
# TODO # TODO
......
...@@ -2,6 +2,7 @@ import numpy as np ...@@ -2,6 +2,7 @@ import numpy as np
import scipy.sparse import scipy.sparse
import aesara import aesara
from aesara import scalar as aes
from aesara.graph.type import HasDataType from aesara.graph.type import HasDataType
from aesara.tensor.type import TensorType from aesara.tensor.type import TensorType
...@@ -106,22 +107,28 @@ class SparseTensorType(TensorType, HasDataType): ...@@ -106,22 +107,28 @@ class SparseTensorType(TensorType, HasDataType):
and value.dtype == self.dtype and value.dtype == self.dtype
): ):
return value return value
if strict: if strict:
raise TypeError( raise TypeError(
f"{value} is not sparse, or not the right dtype (is {value.dtype}, " f"{value} is not sparse, or not the right dtype (is {value.dtype}, "
f"expected {self.dtype})" f"expected {self.dtype})"
) )
# The input format could be converted here # The input format could be converted here
if allow_downcast: if allow_downcast:
sp = self.format_cls[self.format](value, dtype=self.dtype) sp = self.format_cls[self.format](value, dtype=self.dtype)
else: else:
sp = self.format_cls[self.format](value) data = self.format_cls[self.format](value)
if str(sp.dtype) != self.dtype: up_dtype = aes.upcast(self.dtype, data.dtype)
if up_dtype != self.dtype:
raise NotImplementedError( raise NotImplementedError(
f"Expected {self.dtype} dtype but got {sp.dtype}" f"Expected {self.dtype} dtype but got {data.dtype}"
) )
sp = data.astype(up_dtype)
if sp.format != self.format: if sp.format != self.format:
raise NotImplementedError() raise NotImplementedError()
return sp return sp
@classmethod @classmethod
......
...@@ -1465,11 +1465,21 @@ def complex(real, imag): ...@@ -1465,11 +1465,21 @@ def complex(real, imag):
"""Return complex-valued tensor with `real` and `imag` components""" """Return complex-valued tensor with `real` and `imag` components"""
@scalar_elemwise @scalar_elemwise(symbolname="conj")
def conj(z): def _conj(z):
"""Return the complex conjugate of `z`.""" """Return the complex conjugate of `z`."""
def conjugate(x):
    """Return the element-wise complex conjugate of `x`.

    For non-complex dtypes the input is returned as-is, avoiding
    a useless conjugate `Op` in the graph.
    """
    tensor_x = as_tensor_variable(x)
    if tensor_x.type.dtype in complex_dtypes:
        return _conj(tensor_x)
    return tensor_x


conj = conjugate
@scalar_elemwise @scalar_elemwise
def complex_from_polar(abs, angle): def complex_from_polar(abs, angle):
"""Return complex-valued tensor from polar coordinate specification.""" """Return complex-valued tensor from polar coordinate specification."""
...@@ -2931,6 +2941,7 @@ __all__ = [ ...@@ -2931,6 +2941,7 @@ __all__ = [
"angle", "angle",
"complex", "complex",
"conj", "conj",
"conjugate",
"complex_from_polar", "complex_from_polar",
"sum", "sum",
"prod", "prod",
......
...@@ -58,6 +58,7 @@ from aesara.tensor.math import ( ...@@ -58,6 +58,7 @@ from aesara.tensor.math import (
Prod, Prod,
ProdWithoutZeros, ProdWithoutZeros,
Sum, Sum,
_conj,
) )
from aesara.tensor.math import abs as at_abs from aesara.tensor.math import abs as at_abs
from aesara.tensor.math import ( from aesara.tensor.math import (
...@@ -86,6 +87,7 @@ from aesara.tensor.math import true_div ...@@ -86,6 +87,7 @@ from aesara.tensor.math import true_div
from aesara.tensor.shape import Shape, Shape_i from aesara.tensor.shape import Shape, Shape_i
from aesara.tensor.subtensor import Subtensor from aesara.tensor.subtensor import Subtensor
from aesara.tensor.type import ( from aesara.tensor.type import (
complex_dtypes,
uint_dtypes, uint_dtypes,
values_eq_approx_remove_inf, values_eq_approx_remove_inf,
values_eq_approx_remove_inf_nan, values_eq_approx_remove_inf_nan,
...@@ -3552,3 +3554,13 @@ local_sigmoid_logit = PatternSub( ...@@ -3552,3 +3554,13 @@ local_sigmoid_logit = PatternSub(
) )
register_canonicalize(local_sigmoid_logit) register_canonicalize(local_sigmoid_logit)
register_specialize(local_sigmoid_logit) register_specialize(local_sigmoid_logit)
@register_canonicalize
@register_useless
@local_optimizer([_conj])
def local_useless_conj(fgraph, node):
    r"""Remove `conj` `Op`\s applied to non-imaginary variable types."""
    inp = node.inputs[0]
    # Conjugation is the identity on non-complex data, so the node
    # can be replaced by its own input; otherwise leave it alone.
    if inp.type.dtype in complex_dtypes:
        return None
    return [inp]
...@@ -3170,6 +3170,15 @@ SqrtTester = elemwise_checker(sparse.sqrt, np.sqrt, gap=(0, 10)) ...@@ -3170,6 +3170,15 @@ SqrtTester = elemwise_checker(sparse.sqrt, np.sqrt, gap=(0, 10))
ConjTester = elemwise_checker(sparse.conj, np.conj, grad_test=False) ConjTester = elemwise_checker(sparse.conj, np.conj, grad_test=False)
def test_useless_conj():
    """`conj` builds a new node only for complex sparse variables."""
    z = sparse.SparseTensorType("csr", dtype="complex128")()
    # Complex input: an actual conjugate is applied
    assert z.conj() is not z

    r = sparse.SparseTensorType("csr", dtype="float64")()
    # Non-complex input: returned unchanged
    assert r.conj() is r
class TestMulSV: class TestMulSV:
def test_mul_s_v_grad(self): def test_mul_s_v_grad(self):
sp_types = {"csc": sp.sparse.csc_matrix, "csr": sp.sparse.csr_matrix} sp_types = {"csc": sp.sparse.csc_matrix, "csr": sp.sparse.csr_matrix}
......
...@@ -13,69 +13,72 @@ from aesara.tensor.type import DenseTensorType ...@@ -13,69 +13,72 @@ from aesara.tensor.type import DenseTensorType
class TestSparseVariable: class TestSparseVariable:
@pytest.mark.parametrize( @pytest.mark.parametrize(
"method, exp_type, cm", "method, exp_type, cm, x",
[ [
("__abs__", DenseTensorType, None), ("__abs__", DenseTensorType, None, None),
("__neg__", SparseTensorType, ExitStack()), ("__neg__", SparseTensorType, ExitStack(), None),
("__ceil__", DenseTensorType, None), ("__ceil__", DenseTensorType, None, None),
("__floor__", DenseTensorType, None), ("__floor__", DenseTensorType, None, None),
("__trunc__", DenseTensorType, None), ("__trunc__", DenseTensorType, None, None),
("transpose", DenseTensorType, None), ("transpose", DenseTensorType, None, None),
("any", DenseTensorType, None), ("any", DenseTensorType, None, None),
("all", DenseTensorType, None), ("all", DenseTensorType, None, None),
("flatten", DenseTensorType, None), ("flatten", DenseTensorType, None, None),
("ravel", DenseTensorType, None), ("ravel", DenseTensorType, None, None),
("arccos", DenseTensorType, None), ("arccos", DenseTensorType, None, None),
("arcsin", DenseTensorType, None), ("arcsin", DenseTensorType, None, None),
("arctan", DenseTensorType, None), ("arctan", DenseTensorType, None, None),
("arccosh", DenseTensorType, None), ("arccosh", DenseTensorType, None, None),
("arcsinh", DenseTensorType, None), ("arcsinh", DenseTensorType, None, None),
("arctanh", DenseTensorType, None), ("arctanh", DenseTensorType, None, None),
("ceil", DenseTensorType, None), ("ceil", DenseTensorType, None, None),
("cos", DenseTensorType, None), ("cos", DenseTensorType, None, None),
("cosh", DenseTensorType, None), ("cosh", DenseTensorType, None, None),
("deg2rad", DenseTensorType, None), ("deg2rad", DenseTensorType, None, None),
("exp", DenseTensorType, None), ("exp", DenseTensorType, None, None),
("exp2", DenseTensorType, None), ("exp2", DenseTensorType, None, None),
("expm1", DenseTensorType, None), ("expm1", DenseTensorType, None, None),
("floor", DenseTensorType, None), ("floor", DenseTensorType, None, None),
("log", DenseTensorType, None), ("log", DenseTensorType, None, None),
("log10", DenseTensorType, None), ("log10", DenseTensorType, None, None),
("log1p", DenseTensorType, None), ("log1p", DenseTensorType, None, None),
("log2", DenseTensorType, None), ("log2", DenseTensorType, None, None),
("rad2deg", DenseTensorType, None), ("rad2deg", DenseTensorType, None, None),
("sin", DenseTensorType, None), ("sin", DenseTensorType, None, None),
("sinh", DenseTensorType, None), ("sinh", DenseTensorType, None, None),
("sqrt", DenseTensorType, None), ("sqrt", DenseTensorType, None, None),
("tan", DenseTensorType, None), ("tan", DenseTensorType, None, None),
("tanh", DenseTensorType, None), ("tanh", DenseTensorType, None, None),
("copy", DenseTensorType, None), ("copy", DenseTensorType, None, None),
("sum", DenseTensorType, ExitStack()), ("sum", DenseTensorType, ExitStack(), None),
("prod", DenseTensorType, None), ("prod", DenseTensorType, None, None),
("mean", DenseTensorType, None), ("mean", DenseTensorType, None, None),
("var", DenseTensorType, None), ("var", DenseTensorType, None, None),
("std", DenseTensorType, None), ("std", DenseTensorType, None, None),
("min", DenseTensorType, None), ("min", DenseTensorType, None, None),
("max", DenseTensorType, None), ("max", DenseTensorType, None, None),
("argmin", DenseTensorType, None), ("argmin", DenseTensorType, None, None),
("argmax", DenseTensorType, None), ("argmax", DenseTensorType, None, None),
("nonzero", DenseTensorType, ExitStack()), ("nonzero", DenseTensorType, ExitStack(), None),
("nonzero_values", DenseTensorType, None), ("nonzero_values", DenseTensorType, None, None),
("argsort", DenseTensorType, ExitStack()), ("argsort", DenseTensorType, ExitStack(), None),
("conj", DenseTensorType, None), ("conj", SparseTensorType, ExitStack(), at.cmatrix("x")),
("round", DenseTensorType, None), ("round", DenseTensorType, None, None),
("trace", DenseTensorType, None), ("trace", DenseTensorType, None, None),
("zeros_like", SparseTensorType, ExitStack()), ("zeros_like", SparseTensorType, ExitStack(), None),
("ones_like", DenseTensorType, ExitStack()), ("ones_like", DenseTensorType, ExitStack(), None),
("cumsum", DenseTensorType, None), ("cumsum", DenseTensorType, None, None),
("cumprod", DenseTensorType, None), ("cumprod", DenseTensorType, None, None),
("ptp", DenseTensorType, None), ("ptp", DenseTensorType, None, None),
("squeeze", DenseTensorType, None), ("squeeze", DenseTensorType, None, None),
("diagonal", DenseTensorType, None), ("diagonal", DenseTensorType, None, None),
], ],
) )
def test_unary(self, method, exp_type, cm): def test_unary(self, method, exp_type, cm, x):
if x is None:
x = at.dmatrix("x") x = at.dmatrix("x")
x = sparse.csr_from_dense(x) x = sparse.csr_from_dense(x)
method_to_call = getattr(x, method) method_to_call = getattr(x, method)
...@@ -98,7 +101,7 @@ class TestSparseVariable: ...@@ -98,7 +101,7 @@ class TestSparseVariable:
assert all(isinstance(out.type, exp_type) for out in z_outs) assert all(isinstance(out.type, exp_type) for out in z_outs)
f = aesara.function([x], z, on_unused_input="ignore") f = aesara.function([x], z, on_unused_input="ignore", allow_input_downcast=True)
res = f([[1.1, 0.0, 2.0], [-1.0, 0.0, 0.0]]) res = f([[1.1, 0.0, 2.0], [-1.0, 0.0, 0.0]])
......
...@@ -696,7 +696,7 @@ TestComplexFromPolarBroadcast = makeBroadcastTester( ...@@ -696,7 +696,7 @@ TestComplexFromPolarBroadcast = makeBroadcastTester(
) )
TestConjBroadcast = makeBroadcastTester( TestConjBroadcast = makeBroadcastTester(
op=conj, expected=np.conj, good=_good_broadcast_unary_normal op=conj, expected=np.conj, good={"complex": _good_broadcast_unary_normal["complex"]}
) )
...@@ -2567,6 +2567,10 @@ class TestTensorInstanceMethods: ...@@ -2567,6 +2567,10 @@ class TestTensorInstanceMethods:
assert_array_equal(Z.conj().eval({Z: z}), z.conj()) assert_array_equal(Z.conj().eval({Z: z}), z.conj())
assert_array_equal(Z.conjugate().eval({Z: z}), z.conj()) assert_array_equal(Z.conjugate().eval({Z: z}), z.conj())
# No conjugate when the data type isn't complex
assert X.type.dtype not in complex_dtypes
assert X.conj() is X
def test_round(self): def test_round(self):
X, _ = self.vars X, _ = self.vars
x, _ = self.vals x, _ = self.vals
......
...@@ -34,7 +34,7 @@ from aesara.tensor.basic_opt import local_dimshuffle_lift ...@@ -34,7 +34,7 @@ from aesara.tensor.basic_opt import local_dimshuffle_lift
from aesara.tensor.blas import Dot22, Gemv from aesara.tensor.blas import Dot22, Gemv
from aesara.tensor.blas_c import CGemv from aesara.tensor.blas_c import CGemv
from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from aesara.tensor.math import Dot, MaxAndArgmax, Prod, Sum from aesara.tensor.math import Dot, MaxAndArgmax, Prod, Sum, _conj
from aesara.tensor.math import abs as at_abs from aesara.tensor.math import abs as at_abs
from aesara.tensor.math import add from aesara.tensor.math import add
from aesara.tensor.math import all as at_all from aesara.tensor.math import all as at_all
...@@ -119,6 +119,7 @@ from aesara.tensor.type import ( ...@@ -119,6 +119,7 @@ from aesara.tensor.type import (
values_eq_approx_remove_nan, values_eq_approx_remove_nan,
vector, vector,
vectors, vectors,
zscalar,
) )
from aesara.tensor.var import TensorConstant from aesara.tensor.var import TensorConstant
from tests import unittest_tools as utt from tests import unittest_tools as utt
...@@ -4619,3 +4620,20 @@ def test_local_logit_sigmoid(): ...@@ -4619,3 +4620,20 @@ def test_local_logit_sigmoid():
fg = optimize(FunctionGraph([x], [out])) fg = optimize(FunctionGraph([x], [out]))
assert not list(fg.toposort()) assert not list(fg.toposort())
assert fg.inputs[0] is fg.outputs[0] assert fg.inputs[0] is fg.outputs[0]
def test_local_useless_conj():
    """`local_useless_conj` strips `_conj` from real graphs but keeps it for complex ones."""
    opt_mode = get_default_mode().including(
        "canonicalization", "local_useless_conj"
    )

    # Real-valued input: the conjugate is a no-op and should be removed
    real_in = scalar()
    fn = function([real_in], _conj(real_in), mode=opt_mode)
    assert not any(node.op == _conj for node in fn.maker.fgraph.apply_nodes)

    # Complex-valued input: the conjugate is meaningful and must remain
    cplx_in = zscalar()
    fn = function([cplx_in], _conj(cplx_in), mode=opt_mode)
    assert any(node.op == _conj for node in fn.maker.fgraph.apply_nodes)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论