提交 409a4b2a authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Revert accidental structured sparse changes

*Remove `structured_prefix` for Ops that map 0->0 *Do not introduce new structured ops that don't map 0->0 besides the ones that already existed
上级 5c350ab2
......@@ -41,7 +41,7 @@ def structured_elemwise(tensor_op):
@structured_elemwise(ptm.abs)
def structured_abs(x):
def abs(x):
"""
Compute abs(x) for all non-zero elements of x.
"""
......@@ -61,13 +61,6 @@ def structured_exp(x):
"""
@structured_elemwise(ptm.exp2)
def structured_exp2(x):
"""
Compute exp2(x) for all non-zero elements of x.
"""
@structured_elemwise(ptm.log)
def structured_log(x):
"""
......@@ -75,20 +68,6 @@ def structured_log(x):
"""
@structured_elemwise(ptm.log2)
def structured_log2(x):
"""
Compute log2(x) for all non-zero elements of x.
"""
@structured_elemwise(ptm.log10)
def structured_log10(x):
"""
Compute log10(x) for all non-zero elements of x.
"""
@structured_elemwise(ptm.pow)
def structured_pow(x, y):
"""
......@@ -118,161 +97,133 @@ def structured_add(x, y):
@structured_elemwise(ptm.sin)
def structured_sin(x):
def sin(x):
"""
Compute sin(x) for all non-zero elements of x.
"""
@structured_elemwise(ptm.sinh)
def structured_sinh(x):
def sinh(x):
"""
Compute sinh(x) for all non-zero elements of x.
"""
@structured_elemwise(ptm.arcsin)
def structured_arcsin(x):
def arcsin(x):
"""
Compute arcsin(x) for all non-zero elements of x.
"""
@structured_elemwise(ptm.arcsinh)
def structured_arcsinh(x):
def arcsinh(x):
"""
Compute arcsinh(x) for all non-zero elements of x.
"""
@structured_elemwise(ptm.cos)
def structured_cos(x):
"""
Compute cos(x) for all non-zero elements of x.
"""
@structured_elemwise(ptm.cosh)
def structured_cosh(x):
"""
Compute cosh(x) for all non-zero elements of x.
"""
@structured_elemwise(ptm.arccos)
def structured_arccos(x):
"""
Compute arccos(x) for all non-zero elements of x.
"""
@structured_elemwise(ptm.arccosh)
def structured_arccosh(x):
"""
Compute arccosh(x) for all non-zero elements of x.
"""
@structured_elemwise(ptm.tan)
def structured_tan(x):
def tan(x):
"""
Compute tan(x) for all non-zero elements of x.
"""
@structured_elemwise(ptm.tanh)
def structured_tanh(x):
def tanh(x):
"""
Compute tanh(x) for all non-zero elements of x.
"""
@structured_elemwise(ptm.arctan)
def structured_arctan(x):
def arctan(x):
"""
Compute arctan(x) for all non-zero elements of x.
"""
@structured_elemwise(ptm.arctanh)
def structured_arctanh(x):
def arctanh(x):
"""
Compute arctanh(x) for all non-zero elements of x.
"""
@structured_elemwise(ptm.round_half_to_even)
def structured_rint(x):
def rint(x):
"""
Compute round_half_to_even(x) for all non-zero elements of x.
"""
@structured_elemwise(ptm.sign)
def structured_sign(x):
def sign(x):
"""
Compute sign(x) for all non-zero elements of x.
"""
@structured_elemwise(ptm.ceil)
def structured_ceil(x):
def ceil(x):
"""
Compute ceil(x) for all non-zero elements of x.
"""
@structured_elemwise(ptm.floor)
def structured_floor(x):
def floor(x):
"""
Compute floor(x) for all non-zero elements of x.
"""
@structured_elemwise(ptm.log1p)
def structured_log1p(x):
def log1p(x):
"""
Compute log(1 + x) for all non-zero elements of x.
"""
@structured_elemwise(ptm.expm1)
def structured_expm1(x):
def expm1(x):
"""
Compute exp(x) - 1 for all non-zero elements of x.
"""
@structured_elemwise(ptm.deg2rad)
def structured_deg2rad(x):
def deg2rad(x):
"""
Convert degrees to radians for all non-zero elements of x.
"""
@structured_elemwise(ptm.rad2deg)
def structured_rad2deg(x):
def rad2deg(x):
"""
Convert radians to degrees for all non-zero elements of x.
"""
@structured_elemwise(ptm.trunc)
def structured_trunc(x):
def trunc(x):
"""
Truncate the decimal part of x for all non-zero elements of x.
"""
@structured_elemwise(ptm.sqr)
def structured_sqr(x):
def sqr(x):
"""
Compute sqr(x) for all non-zero elements of x.
"""
@structured_elemwise(ptm.sqrt)
def structured_sqrt(x):
def sqrt(x):
"""
Compute sqrt(x) for all non-zero elements of x.
"""
......@@ -292,7 +243,7 @@ def conjugate(x):
return _conj(_x)
structured_conjugate = conjugate
structured_conjugate = conj = conjugate
class SpSum(Op):
......
......@@ -1382,32 +1382,32 @@ StructuredAddTester = elemwise_checker(
name="StructuredAddTester",
)
SinTester = elemwise_checker(psm.structured_sin, np.sin)
SinTester = elemwise_checker(psm.sin, np.sin)
TanTester = elemwise_checker(psm.structured_tan, np.tan, gap=(-1, 1))
TanTester = elemwise_checker(psm.tan, np.tan, gap=(-1, 1))
ArcsinTester = elemwise_checker(
psm.structured_arcsin, np.arcsin, gap=(-1, 1), gap_grad=(-0.99, 0.99)
psm.arcsinh, np.arcsin, gap=(-1, 1), gap_grad=(-0.99, 0.99)
)
ArctanTester = elemwise_checker(psm.structured_arctan, np.arctan)
ArctanTester = elemwise_checker(psm.arctan, np.arctan)
SinhTester = elemwise_checker(psm.structured_sinh, np.sinh)
SinhTester = elemwise_checker(psm.sinh, np.sinh)
ArcsinhTester = elemwise_checker(psm.structured_arcsinh, np.arcsinh, gap=(-1, 1))
ArcsinhTester = elemwise_checker(psm.arcsinh, np.arcsinh, gap=(-1, 1))
TanhTester = elemwise_checker(psm.structured_tanh, np.tanh, gap=(-1, 1))
TanhTester = elemwise_checker(psm.tanh, np.tanh, gap=(-1, 1))
ArctanhTester = elemwise_checker(
psm.structured_arctanh, np.arctanh, gap=(-0.9, 1), gap_grad=(-0.9, 0.95)
psm.arctanh, np.arctanh, gap=(-0.9, 1), gap_grad=(-0.9, 0.95)
)
RintTester = elemwise_checker(
psm.structured_rint, np.rint, grad_test=False, test_dtypes=float_dtypes
psm.rint, np.rint, grad_test=False, test_dtypes=float_dtypes
)
SgnTester = elemwise_checker(
psm.structured_sign,
psm.sign,
np.sign,
grad_test=False,
test_dtypes=[
......@@ -1416,46 +1416,46 @@ SgnTester = elemwise_checker(
)
CeilTester = elemwise_checker(
psm.structured_ceil,
psm.ceil,
np.ceil,
grad_test=False,
test_dtypes=[m for m in all_dtypes if m not in complex_dtypes],
)
FloorTester = elemwise_checker(
psm.structured_floor,
psm.floor,
np.floor,
grad_test=False,
test_dtypes=[m for m in all_dtypes if m not in complex_dtypes],
)
Log1pTester = elemwise_checker(psm.structured_log1p, np.log1p, gap=(0.5, 10))
Log1pTester = elemwise_checker(psm.log1p, np.log1p, gap=(0.5, 10))
Expm1Tester = elemwise_checker(psm.structured_expm1, np.expm1)
Expm1Tester = elemwise_checker(psm.expm1, np.expm1)
Deg2radTester = elemwise_checker(
psm.structured_deg2rad,
psm.deg2rad,
np.deg2rad,
test_dtypes=[m for m in all_dtypes if m not in complex_dtypes],
)
Rad2degTester = elemwise_checker(
psm.structured_rad2deg,
psm.rad2deg,
np.rad2deg,
test_dtypes=[m for m in all_dtypes if m not in complex_dtypes],
)
TruncTester = elemwise_checker(
psm.structured_trunc,
psm.trunc,
np.trunc,
test_dtypes=[m for m in all_dtypes if m not in complex_dtypes],
grad_test=False,
)
SqrTester = elemwise_checker(psm.structured_sqr, lambda x: x * x)
SqrTester = elemwise_checker(psm.sqr, lambda x: x * x)
SqrtTester = elemwise_checker(psm.structured_sqrt, np.sqrt, gap=(0, 10))
SqrtTester = elemwise_checker(psm.sqrt, np.sqrt, gap=(0, 10))
ConjTester = elemwise_checker(psm.structured_conjugate, np.conj, grad_test=False)
ConjTester = elemwise_checker(psm.conjugate, np.conj, grad_test=False)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论