提交 409a4b2a authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Revert accidental structured sparse changes

*Remove `structured_prefix` for Ops that map 0->0 *Do not introduce new structured ops that don't map 0->0 besides the ones that already existed
上级 5c350ab2
...@@ -41,7 +41,7 @@ def structured_elemwise(tensor_op): ...@@ -41,7 +41,7 @@ def structured_elemwise(tensor_op):
@structured_elemwise(ptm.abs) @structured_elemwise(ptm.abs)
def structured_abs(x): def abs(x):
""" """
Compute abs(x) for all non-zero elements of x. Compute abs(x) for all non-zero elements of x.
""" """
...@@ -61,13 +61,6 @@ def structured_exp(x): ...@@ -61,13 +61,6 @@ def structured_exp(x):
""" """
@structured_elemwise(ptm.exp2)
def structured_exp2(x):
"""
Compute exp2(x) for all non-zero elements of x.
"""
@structured_elemwise(ptm.log) @structured_elemwise(ptm.log)
def structured_log(x): def structured_log(x):
""" """
...@@ -75,20 +68,6 @@ def structured_log(x): ...@@ -75,20 +68,6 @@ def structured_log(x):
""" """
@structured_elemwise(ptm.log2)
def structured_log2(x):
"""
Compute log2(x) for all non-zero elements of x.
"""
@structured_elemwise(ptm.log10)
def structured_log10(x):
"""
Compute log10(x) for all non-zero elements of x.
"""
@structured_elemwise(ptm.pow) @structured_elemwise(ptm.pow)
def structured_pow(x, y): def structured_pow(x, y):
""" """
...@@ -118,161 +97,133 @@ def structured_add(x, y): ...@@ -118,161 +97,133 @@ def structured_add(x, y):
@structured_elemwise(ptm.sin) @structured_elemwise(ptm.sin)
def structured_sin(x): def sin(x):
""" """
Compute sin(x) for all non-zero elements of x. Compute sin(x) for all non-zero elements of x.
""" """
@structured_elemwise(ptm.sinh) @structured_elemwise(ptm.sinh)
def structured_sinh(x): def sinh(x):
""" """
Compute sinh(x) for all non-zero elements of x. Compute sinh(x) for all non-zero elements of x.
""" """
@structured_elemwise(ptm.arcsin) @structured_elemwise(ptm.arcsin)
def structured_arcsin(x): def arcsin(x):
""" """
Compute arcsin(x) for all non-zero elements of x. Compute arcsin(x) for all non-zero elements of x.
""" """
@structured_elemwise(ptm.arcsinh) @structured_elemwise(ptm.arcsinh)
def structured_arcsinh(x): def arcsinh(x):
""" """
Compute arcsinh(x) for all non-zero elements of x. Compute arcsinh(x) for all non-zero elements of x.
""" """
@structured_elemwise(ptm.cos)
def structured_cos(x):
"""
Compute cos(x) for all non-zero elements of x.
"""
@structured_elemwise(ptm.cosh)
def structured_cosh(x):
"""
Compute cosh(x) for all non-zero elements of x.
"""
@structured_elemwise(ptm.arccos)
def structured_arccos(x):
"""
Compute arccos(x) for all non-zero elements of x.
"""
@structured_elemwise(ptm.arccosh)
def structured_arccosh(x):
"""
Compute arccosh(x) for all non-zero elements of x.
"""
@structured_elemwise(ptm.tan) @structured_elemwise(ptm.tan)
def structured_tan(x): def tan(x):
""" """
Compute tan(x) for all non-zero elements of x. Compute tan(x) for all non-zero elements of x.
""" """
@structured_elemwise(ptm.tanh) @structured_elemwise(ptm.tanh)
def structured_tanh(x): def tanh(x):
""" """
Compute tanh(x) for all non-zero elements of x. Compute tanh(x) for all non-zero elements of x.
""" """
@structured_elemwise(ptm.arctan) @structured_elemwise(ptm.arctan)
def structured_arctan(x): def arctan(x):
""" """
Compute arctan(x) for all non-zero elements of x. Compute arctan(x) for all non-zero elements of x.
""" """
@structured_elemwise(ptm.arctanh) @structured_elemwise(ptm.arctanh)
def structured_arctanh(x): def arctanh(x):
""" """
Compute arctanh(x) for all non-zero elements of x. Compute arctanh(x) for all non-zero elements of x.
""" """
@structured_elemwise(ptm.round_half_to_even) @structured_elemwise(ptm.round_half_to_even)
def structured_rint(x): def rint(x):
""" """
Compute round_half_to_even(x) for all non-zero elements of x. Compute round_half_to_even(x) for all non-zero elements of x.
""" """
@structured_elemwise(ptm.sign) @structured_elemwise(ptm.sign)
def structured_sign(x): def sign(x):
""" """
Compute sign(x) for all non-zero elements of x. Compute sign(x) for all non-zero elements of x.
""" """
@structured_elemwise(ptm.ceil) @structured_elemwise(ptm.ceil)
def structured_ceil(x): def ceil(x):
""" """
Compute ceil(x) for all non-zero elements of x. Compute ceil(x) for all non-zero elements of x.
""" """
@structured_elemwise(ptm.floor) @structured_elemwise(ptm.floor)
def structured_floor(x): def floor(x):
""" """
Compute floor(x) for all non-zero elements of x. Compute floor(x) for all non-zero elements of x.
""" """
@structured_elemwise(ptm.log1p) @structured_elemwise(ptm.log1p)
def structured_log1p(x): def log1p(x):
""" """
Compute log(1 + x) for all non-zero elements of x. Compute log(1 + x) for all non-zero elements of x.
""" """
@structured_elemwise(ptm.expm1) @structured_elemwise(ptm.expm1)
def structured_expm1(x): def expm1(x):
""" """
Compute exp(x) - 1 for all non-zero elements of x. Compute exp(x) - 1 for all non-zero elements of x.
""" """
@structured_elemwise(ptm.deg2rad) @structured_elemwise(ptm.deg2rad)
def structured_deg2rad(x): def deg2rad(x):
""" """
Convert degrees to radians for all non-zero elements of x. Convert degrees to radians for all non-zero elements of x.
""" """
@structured_elemwise(ptm.rad2deg) @structured_elemwise(ptm.rad2deg)
def structured_rad2deg(x): def rad2deg(x):
""" """
Convert radians to degrees for all non-zero elements of x. Convert radians to degrees for all non-zero elements of x.
""" """
@structured_elemwise(ptm.trunc) @structured_elemwise(ptm.trunc)
def structured_trunc(x): def trunc(x):
""" """
Truncate the decimal part of x for all non-zero elements of x. Truncate the decimal part of x for all non-zero elements of x.
""" """
@structured_elemwise(ptm.sqr) @structured_elemwise(ptm.sqr)
def structured_sqr(x): def sqr(x):
""" """
Compute sqr(x) for all non-zero elements of x. Compute sqr(x) for all non-zero elements of x.
""" """
@structured_elemwise(ptm.sqrt) @structured_elemwise(ptm.sqrt)
def structured_sqrt(x): def sqrt(x):
""" """
Compute sqrt(x) for all non-zero elements of x. Compute sqrt(x) for all non-zero elements of x.
""" """
...@@ -292,7 +243,7 @@ def conjugate(x): ...@@ -292,7 +243,7 @@ def conjugate(x):
return _conj(_x) return _conj(_x)
structured_conjugate = conjugate structured_conjugate = conj = conjugate
class SpSum(Op): class SpSum(Op):
......
...@@ -1382,32 +1382,32 @@ StructuredAddTester = elemwise_checker( ...@@ -1382,32 +1382,32 @@ StructuredAddTester = elemwise_checker(
name="StructuredAddTester", name="StructuredAddTester",
) )
SinTester = elemwise_checker(psm.structured_sin, np.sin) SinTester = elemwise_checker(psm.sin, np.sin)
TanTester = elemwise_checker(psm.structured_tan, np.tan, gap=(-1, 1)) TanTester = elemwise_checker(psm.tan, np.tan, gap=(-1, 1))
ArcsinTester = elemwise_checker( ArcsinTester = elemwise_checker(
psm.structured_arcsin, np.arcsin, gap=(-1, 1), gap_grad=(-0.99, 0.99) psm.arcsinh, np.arcsin, gap=(-1, 1), gap_grad=(-0.99, 0.99)
) )
ArctanTester = elemwise_checker(psm.structured_arctan, np.arctan) ArctanTester = elemwise_checker(psm.arctan, np.arctan)
SinhTester = elemwise_checker(psm.structured_sinh, np.sinh) SinhTester = elemwise_checker(psm.sinh, np.sinh)
ArcsinhTester = elemwise_checker(psm.structured_arcsinh, np.arcsinh, gap=(-1, 1)) ArcsinhTester = elemwise_checker(psm.arcsinh, np.arcsinh, gap=(-1, 1))
TanhTester = elemwise_checker(psm.structured_tanh, np.tanh, gap=(-1, 1)) TanhTester = elemwise_checker(psm.tanh, np.tanh, gap=(-1, 1))
ArctanhTester = elemwise_checker( ArctanhTester = elemwise_checker(
psm.structured_arctanh, np.arctanh, gap=(-0.9, 1), gap_grad=(-0.9, 0.95) psm.arctanh, np.arctanh, gap=(-0.9, 1), gap_grad=(-0.9, 0.95)
) )
RintTester = elemwise_checker( RintTester = elemwise_checker(
psm.structured_rint, np.rint, grad_test=False, test_dtypes=float_dtypes psm.rint, np.rint, grad_test=False, test_dtypes=float_dtypes
) )
SgnTester = elemwise_checker( SgnTester = elemwise_checker(
psm.structured_sign, psm.sign,
np.sign, np.sign,
grad_test=False, grad_test=False,
test_dtypes=[ test_dtypes=[
...@@ -1416,46 +1416,46 @@ SgnTester = elemwise_checker( ...@@ -1416,46 +1416,46 @@ SgnTester = elemwise_checker(
) )
CeilTester = elemwise_checker( CeilTester = elemwise_checker(
psm.structured_ceil, psm.ceil,
np.ceil, np.ceil,
grad_test=False, grad_test=False,
test_dtypes=[m for m in all_dtypes if m not in complex_dtypes], test_dtypes=[m for m in all_dtypes if m not in complex_dtypes],
) )
FloorTester = elemwise_checker( FloorTester = elemwise_checker(
psm.structured_floor, psm.floor,
np.floor, np.floor,
grad_test=False, grad_test=False,
test_dtypes=[m for m in all_dtypes if m not in complex_dtypes], test_dtypes=[m for m in all_dtypes if m not in complex_dtypes],
) )
Log1pTester = elemwise_checker(psm.structured_log1p, np.log1p, gap=(0.5, 10)) Log1pTester = elemwise_checker(psm.log1p, np.log1p, gap=(0.5, 10))
Expm1Tester = elemwise_checker(psm.structured_expm1, np.expm1) Expm1Tester = elemwise_checker(psm.expm1, np.expm1)
Deg2radTester = elemwise_checker( Deg2radTester = elemwise_checker(
psm.structured_deg2rad, psm.deg2rad,
np.deg2rad, np.deg2rad,
test_dtypes=[m for m in all_dtypes if m not in complex_dtypes], test_dtypes=[m for m in all_dtypes if m not in complex_dtypes],
) )
Rad2degTester = elemwise_checker( Rad2degTester = elemwise_checker(
psm.structured_rad2deg, psm.rad2deg,
np.rad2deg, np.rad2deg,
test_dtypes=[m for m in all_dtypes if m not in complex_dtypes], test_dtypes=[m for m in all_dtypes if m not in complex_dtypes],
) )
TruncTester = elemwise_checker( TruncTester = elemwise_checker(
psm.structured_trunc, psm.trunc,
np.trunc, np.trunc,
test_dtypes=[m for m in all_dtypes if m not in complex_dtypes], test_dtypes=[m for m in all_dtypes if m not in complex_dtypes],
grad_test=False, grad_test=False,
) )
SqrTester = elemwise_checker(psm.structured_sqr, lambda x: x * x) SqrTester = elemwise_checker(psm.sqr, lambda x: x * x)
SqrtTester = elemwise_checker(psm.structured_sqrt, np.sqrt, gap=(0, 10)) SqrtTester = elemwise_checker(psm.sqrt, np.sqrt, gap=(0, 10))
ConjTester = elemwise_checker(psm.structured_conjugate, np.conj, grad_test=False) ConjTester = elemwise_checker(psm.conjugate, np.conj, grad_test=False)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论