提交 96122d15 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Fix backwards compatibility in ScalarOp hash

上级 ee107cba
...@@ -1294,13 +1294,12 @@ class ScalarOp(COp): ...@@ -1294,13 +1294,12 @@ class ScalarOp(COp):
return self.grad(inputs, output_gradients) return self.grad(inputs, output_gradients)
def __eq__(self, other): def __eq__(self, other):
test = type(self) is type(other) and getattr( return type(self) is type(other) and getattr(
self, "output_types_preference", None self, "output_types_preference", None
) == getattr(other, "output_types_preference", None) ) == getattr(other, "output_types_preference", None)
return test
def __hash__(self): def __hash__(self):
return hash((type(self), getattr(self, "output_types_preference", 0))) return hash((type(self), getattr(self, "output_types_preference", None)))
def __str__(self): def __str__(self):
if hasattr(self, "name") and self.name: if hasattr(self, "name") and self.name:
......
...@@ -8,6 +8,7 @@ from pytensor.compile.mode import Mode ...@@ -8,6 +8,7 @@ from pytensor.compile.mode import Mode
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.link.c.basic import DualLinker from pytensor.link.c.basic import DualLinker
from pytensor.scalar.basic import ( from pytensor.scalar.basic import (
EQ,
ComplexError, ComplexError,
Composite, Composite,
InRange, InRange,
...@@ -543,3 +544,18 @@ def test_grad_log10(): ...@@ -543,3 +544,18 @@ def test_grad_log10():
b_grad = pytensor.gradient.grad(b, a) b_grad = pytensor.gradient.grad(b, a)
assert b.dtype == "float32" assert b.dtype == "float32"
assert b_grad.dtype == "float32" assert b_grad.dtype == "float32"
def test_scalar_hash_default_output_type_preference():
# Old hash used `getattr(self, "output_type_preference", 0)`
# whereas equality used `getattr(self, "output_type_preference", None)`.
# Since 27d797076668fbf0617654fd9b91f92ddb6737e6,
# output_type_preference is always present (None if not specified),
# which led to C-caching errors when comparing old cached Ops and fresh Ops,
# as they evaluated equal but hashed differently
new_eq = EQ()
old_eq = EQ()
del old_eq.output_types_preference # mimic old Op
assert new_eq == old_eq
assert hash(new_eq) == hash(old_eq)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论