Unverified 提交 f83c05ba authored 作者: Copilot's avatar Copilot 提交者: GitHub

Add `__rtruediv__` and `__rfloordiv__` to Scalar variables (#1701)

Also remove legacy `__div__`
上级 112b3252
......@@ -943,6 +943,12 @@ class _scalar_py_operators:
def __rmul__(self, other):
return mul(other, self)
def __rtruediv__(self, other):
return true_div(other, self)
def __rfloordiv__(self, other):
return int_div(other, self)
def __rmod__(self, other):
return mod(other, self)
......
......@@ -13,7 +13,6 @@ from pytensor.graph.basic import Constant, OptionalApplyType, Variable
from pytensor.graph.utils import MetaType
from pytensor.scalar import (
ComplexError,
IntegerDivisionError,
)
from pytensor.tensor import _get_vector_length
from pytensor.tensor.exceptions import AdvancedIndexingError
......@@ -138,18 +137,6 @@ class _tensor_py_operators:
except (NotImplementedError, TypeError):
return NotImplemented
def __div__(self, other):
# See explanation in __add__ for the error caught
# and the return value in that case
try:
return pt.math.div_proxy(self, other)
except IntegerDivisionError:
# This is to raise the exception that occurs when trying to divide
# two integer arrays (currently forbidden).
raise
except (NotImplementedError, TypeError):
return NotImplemented
def __pow__(self, other):
# See explanation in __add__ for the error caught
# and the return value in that case
......@@ -210,9 +197,6 @@ class _tensor_py_operators:
def __rmul__(self, other):
return pt.math.mul(other, self)
def __rdiv__(self, other):
return pt.math.div_proxy(other, self)
def __rmod__(self, other):
return pt.math.mod(other, self)
......
......@@ -308,9 +308,6 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
def __mul__(self, other):
return px.math.mul(self, other)
def __div__(self, other):
return px.math.div(self, other)
def __pow__(self, other):
return px.math.pow(self, other)
......@@ -341,9 +338,6 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
def __rmul__(self, other):
return px.math.mul(other, self)
def __rdiv__(self, other):
return px.math.div_proxy(other, self)
def __rmod__(self, other):
return px.math.mod(other, self)
......
......@@ -10,7 +10,9 @@ from pytensor.scalar.basic import (
EQ,
ComplexError,
Composite,
IntDiv,
ScalarType,
TrueDiv,
add,
and_,
arccos,
......@@ -531,3 +533,19 @@ def test_scalar_hash_default_output_type_preference():
del old_eq.output_types_preference # mimic old Op
assert new_eq == old_eq
assert hash(new_eq) == hash(old_eq)
def test_rtruediv():
x = ScalarType(dtype="float64")()
y = 1.0 / x
assert isinstance(y.owner.op, TrueDiv)
assert isinstance(y.type, ScalarType)
assert y.eval({x: 2.0}) == 0.5
def test_rfloordiv():
x = ScalarType(dtype="float64")()
y = 5.0 // x
assert isinstance(y.owner.op, IntDiv)
assert isinstance(y.type, ScalarType)
assert y.eval({x: 2.0}) == 2.0
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论