Unverified commit 60ba7c71, authored by Copilot, committed by GitHub

Remove InRange scalar Op (#1699)

* Initial plan * Remove InRange scalar op and add missing __rtruediv__ and __rfloordiv__ * Remove __rtruediv__ and __rfloordiv__ changes per review feedback --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
上级 b9af9523
...@@ -1464,7 +1464,6 @@ class ProfileStats: ...@@ -1464,7 +1464,6 @@ class ProfileStats:
ps.GE, ps.GE,
ps.EQ, ps.EQ,
ps.NEQ, ps.NEQ,
ps.InRange,
ps.Switch, ps.Switch,
ps.OR, ps.OR,
ps.XOR, ps.XOR,
......
...@@ -1650,56 +1650,6 @@ class IsInf(FixedLogicalComparison): ...@@ -1650,56 +1650,6 @@ class IsInf(FixedLogicalComparison):
isinf = IsInf() isinf = IsInf()
class InRange(LogicalComparison):
    """Scalar ternary comparison: is ``x`` inside the interval (low, hi)?

    ``openlow`` / ``openhi`` select strict (open, ``<``/``>``) versus
    non-strict (closed, ``<=``/``>=``) bounds on each side.
    """

    nin = 3

    def __init__(self, openlow, openhi):
        # Stored flags drive both the Python impl and the generated C code.
        self.openlow = openlow
        self.openhi = openhi

    def impl(self, x, low, hi):
        # Python-side evaluation: True iff x satisfies both bound tests.
        above_low = (x > low) if self.openlow else (x >= low)
        below_hi = (x < hi) if self.openhi else (x <= hi)
        return bool(above_low and below_hi)

    def c_code(self, node, name, inputs, outputs, sub):
        # Emit a single C expression combining the two bound checks.
        x, low, hi = inputs
        (z,) = outputs
        low_op = ">" if self.openlow else ">="
        hi_op = "<" if self.openhi else "<="
        return f"{z} = {x} {low_op} {low} && {x} {hi_op} {hi};"

    def get_grad(self, elem):
        # The output is piecewise constant, so the gradient w.r.t. every
        # input is zero; at the bounds it is defined as zero for stability.
        if elem.type in complex_types:
            msg = (
                "No gradient implemented for complex numbers in "
                "class scalar.basic.InRange"
            )
            raise NotImplementedError(msg)
        if elem.type in discrete_types:
            # Discrete inputs get a float gradient in the configured dtype.
            return elem.zeros_like(dtype=config.floatX)
        return elem.zeros_like()

    def L_op(self, inputs, outputs, gout):
        (x, low, hi) = inputs
        (_gz,) = gout
        # Zero gradient for each of x, low and hi (see get_grad).
        return [self.get_grad(inp) for inp in (x, low, hi)]
# Module-level singletons: strict-bound (open interval) and
# inclusive-bound (closed interval) range checks.
inopenrange = InRange(openlow=True, openhi=True)
inclosedrange = InRange(openlow=False, openhi=False)
class Switch(ScalarOp): class Switch(ScalarOp):
nin = 3 nin = 3
nfunc_spec = ("where", 3, 1) nfunc_spec = ("where", 3, 1)
......
...@@ -3,7 +3,6 @@ import pytest ...@@ -3,7 +3,6 @@ import pytest
import pytensor import pytensor
import pytensor.tensor as pt import pytensor.tensor as pt
import tests.unittest_tools as utt
from pytensor.compile.mode import Mode from pytensor.compile.mode import Mode
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.link.c.basic import DualLinker from pytensor.link.c.basic import DualLinker
...@@ -11,7 +10,6 @@ from pytensor.scalar.basic import ( ...@@ -11,7 +10,6 @@ from pytensor.scalar.basic import (
EQ, EQ,
ComplexError, ComplexError,
Composite, Composite,
InRange,
ScalarType, ScalarType,
add, add,
and_, and_,
...@@ -475,32 +473,6 @@ def test_grad_identity(): ...@@ -475,32 +473,6 @@ def test_grad_identity():
pytensor.gradient.grad(l, x) pytensor.gradient.grad(l, x)
def test_grad_inrange():
    """Gradients of InRange are zero everywhere.

    The op is piecewise constant, so all three gradients vanish; at the
    bounds (where the true derivative would be infinite) PyTensor defines
    the gradient to be zero for numerical stability.
    """
    for open_bounds in [(True, True), (False, False)]:
        # Build the op with either strict or inclusive bounds, then
        # differentiate its output w.r.t. all three inputs.
        op = InRange(*open_bounds)
        x = fscalar("x")
        low = fscalar("low")
        high = fscalar("high")
        grads = pytensor.gradient.grad(op(x, low, high), [x, low, high])
        f = pytensor.function([x, low, high], grads)
        # Probe below, at the lower bound, inside, at the upper bound,
        # and above the interval [1, 5]: all gradients must be zero.
        for probe in (0, 1, 2, 5, 7):
            utt.assert_allclose(f(probe, 1, 5), [0, 0, 0])
def test_grad_abs(): def test_grad_abs():
a = fscalar("a") a = fscalar("a")
b = 0.5 * (a + pytensor.tensor.abs(a)) b = 0.5 * (a + pytensor.tensor.abs(a))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论