提交 9f88e1fc authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

OpFromGraph subclasses shouldn't have __props__

When specified, Ops with identical __props__ are considered identical, in that they can be swapped and given the original inputs to obtain the same output.
上级 1509ceef
......@@ -3780,8 +3780,6 @@ class AllocDiag(OpFromGraph):
Wrapper Op for alloc_diag graphs
"""
__props__ = ("axis1", "axis2")
def __init__(self, *args, axis1, axis2, offset, **kwargs):
self.axis1 = axis1
self.axis2 = axis2
......@@ -3789,6 +3787,9 @@ class AllocDiag(OpFromGraph):
super().__init__(*args, **kwargs, strict=True)
def __str__(self):
return f"AllocDiag{{{self.axis1=}, {self.axis2=}, {self.offset=}}}"
@staticmethod
def is_offset_zero(node) -> bool:
"""
......
......@@ -52,14 +52,15 @@ class Einsum(OpFromGraph):
desired. We haven't decided whether we want to provide this functionality.
"""
__props__ = ("subscripts", "path", "optimized")
def __init__(self, *args, subscripts: str, path: PATH, optimized: bool, **kwargs):
self.subscripts = subscripts
self.path = path
self.optimized = optimized
super().__init__(*args, **kwargs, strict=True)
def __str__(self):
return f"Einsum{{{self.subscripts=}, {self.path=}, {self.optimized=}}}"
def _iota(shape: TensorVariable, axis: int) -> TensorVariable:
"""
......
......@@ -37,6 +37,7 @@ from pytensor.tensor.basic import (
TensorFromScalar,
Tri,
alloc,
alloc_diag,
arange,
as_tensor_variable,
atleast_Nd,
......@@ -3793,6 +3794,18 @@ class TestAllocDiag:
)
assert np.all(true_grad_input == grad_input)
def test_multiple_ops_same_graph(self):
"""Regression test when AllocDiag OFG was given insufficient props, causing incompatible Ops to be merged."""
v1 = vector("v1", shape=(2,), dtype="float64")
v2 = vector("v2", shape=(3,), dtype="float64")
a1 = alloc_diag(v1)
a2 = alloc_diag(v2)
fn = function([v1, v2], [a1, a2])
res1, res2 = fn(v1=[np.e, np.e], v2=[np.pi, np.pi, np.pi])
np.testing.assert_allclose(res1, np.eye(2) * np.e)
np.testing.assert_allclose(res2, np.eye(3) * np.pi)
def test_diagonal_negative_axis():
x = np.arange(2 * 3 * 3).reshape((2, 3, 3))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论