Unverified 提交 f83c05ba authored 作者: Copilot's avatar Copilot 提交者: GitHub

Add `__rtruediv__` and `__rfloordiv__` to Scalar variables (#1701)

Also remove legacy `__div__`
上级 112b3252
...@@ -943,6 +943,12 @@ class _scalar_py_operators: ...@@ -943,6 +943,12 @@ class _scalar_py_operators:
def __rmul__(self, other): def __rmul__(self, other):
return mul(other, self) return mul(other, self)
def __rtruediv__(self, other):
return true_div(other, self)
def __rfloordiv__(self, other):
return int_div(other, self)
def __rmod__(self, other): def __rmod__(self, other):
return mod(other, self) return mod(other, self)
......
...@@ -13,7 +13,6 @@ from pytensor.graph.basic import Constant, OptionalApplyType, Variable ...@@ -13,7 +13,6 @@ from pytensor.graph.basic import Constant, OptionalApplyType, Variable
from pytensor.graph.utils import MetaType from pytensor.graph.utils import MetaType
from pytensor.scalar import ( from pytensor.scalar import (
ComplexError, ComplexError,
IntegerDivisionError,
) )
from pytensor.tensor import _get_vector_length from pytensor.tensor import _get_vector_length
from pytensor.tensor.exceptions import AdvancedIndexingError from pytensor.tensor.exceptions import AdvancedIndexingError
...@@ -138,18 +137,6 @@ class _tensor_py_operators: ...@@ -138,18 +137,6 @@ class _tensor_py_operators:
except (NotImplementedError, TypeError): except (NotImplementedError, TypeError):
return NotImplemented return NotImplemented
def __div__(self, other):
# See explanation in __add__ for the error caught
# and the return value in that case
try:
return pt.math.div_proxy(self, other)
except IntegerDivisionError:
# This is to raise the exception that occurs when trying to divide
# two integer arrays (currently forbidden).
raise
except (NotImplementedError, TypeError):
return NotImplemented
def __pow__(self, other): def __pow__(self, other):
# See explanation in __add__ for the error caught # See explanation in __add__ for the error caught
# and the return value in that case # and the return value in that case
...@@ -210,9 +197,6 @@ class _tensor_py_operators: ...@@ -210,9 +197,6 @@ class _tensor_py_operators:
def __rmul__(self, other): def __rmul__(self, other):
return pt.math.mul(other, self) return pt.math.mul(other, self)
def __rdiv__(self, other):
return pt.math.div_proxy(other, self)
def __rmod__(self, other): def __rmod__(self, other):
return pt.math.mod(other, self) return pt.math.mod(other, self)
......
...@@ -308,9 +308,6 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]): ...@@ -308,9 +308,6 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
def __mul__(self, other): def __mul__(self, other):
return px.math.mul(self, other) return px.math.mul(self, other)
def __div__(self, other):
return px.math.div(self, other)
def __pow__(self, other): def __pow__(self, other):
return px.math.pow(self, other) return px.math.pow(self, other)
...@@ -341,9 +338,6 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]): ...@@ -341,9 +338,6 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
def __rmul__(self, other): def __rmul__(self, other):
return px.math.mul(other, self) return px.math.mul(other, self)
def __rdiv__(self, other):
return px.math.div_proxy(other, self)
def __rmod__(self, other): def __rmod__(self, other):
return px.math.mod(other, self) return px.math.mod(other, self)
......
...@@ -10,7 +10,9 @@ from pytensor.scalar.basic import ( ...@@ -10,7 +10,9 @@ from pytensor.scalar.basic import (
EQ, EQ,
ComplexError, ComplexError,
Composite, Composite,
IntDiv,
ScalarType, ScalarType,
TrueDiv,
add, add,
and_, and_,
arccos, arccos,
...@@ -531,3 +533,19 @@ def test_scalar_hash_default_output_type_preference(): ...@@ -531,3 +533,19 @@ def test_scalar_hash_default_output_type_preference():
del old_eq.output_types_preference # mimic old Op del old_eq.output_types_preference # mimic old Op
assert new_eq == old_eq assert new_eq == old_eq
assert hash(new_eq) == hash(old_eq) assert hash(new_eq) == hash(old_eq)
def test_rtruediv():
x = ScalarType(dtype="float64")()
y = 1.0 / x
assert isinstance(y.owner.op, TrueDiv)
assert isinstance(y.type, ScalarType)
assert y.eval({x: 2.0}) == 0.5
def test_rfloordiv():
x = ScalarType(dtype="float64")()
y = 5.0 // x
assert isinstance(y.owner.op, IntDiv)
assert isinstance(y.type, ScalarType)
assert y.eval({x: 2.0}) == 2.0
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论