提交 5aa17801 authored 作者: Will Dean's avatar Will Dean 提交者: Rémi Louf

Add name argument to clone methods via kwargs

上级 ac5fb058
......@@ -167,9 +167,10 @@ class SharedVariable(Variable):
else:
self.container.value = 0 * self.container.value
def clone(self):
def clone(self, **kwargs):
name = kwargs.get("name", self.name)
cp = self.__class__(
name=self.name,
name=name,
type=self.type,
value=None,
strict=None,
......
......@@ -510,9 +510,14 @@ class Variable(Node, Generic[_TypeType, OptionalApplyType]):
pass
return "\n".join(to_print)
def clone(self):
def clone(self, **kwargs):
"""Return a new, un-owned `Variable` like `self`.
Parameters
----------
**kwargs : dict
Optional "name" keyword argument for the copied instance. Same as `self.name` if value not provided.
Returns
-------
Variable instance
......@@ -523,7 +528,8 @@ class Variable(Node, Generic[_TypeType, OptionalApplyType]):
Tags and names are copied to the returned instance.
"""
cp = self.__class__(self.type, None, None, self.name)
name = kwargs.pop("name", self.name)
cp = self.__class__(type=self.type, owner=None, index=None, name=name, **kwargs)
cp.tag = copy(self.tag)
return cp
......@@ -621,8 +627,8 @@ class Variable(Node, Generic[_TypeType, OptionalApplyType]):
class AtomicVariable(Variable[_TypeType, None]):
"""A node type that has no ancestors and should never be considered an input to a graph."""
def __init__(self, type: _TypeType, **kwargs):
super().__init__(type, None, None, **kwargs)
def __init__(self, type: _TypeType, name: Optional[str] = None, **kwargs):
super().__init__(type=type, owner=None, index=None, name=name, **kwargs)
@abc.abstractmethod
def signature(self):
......@@ -656,6 +662,12 @@ class AtomicVariable(Variable[_TypeType, None]):
if value is not None:
raise ValueError("AtomicVariable instances cannot have an index.")
def clone(self, **kwargs):
name = kwargs.pop("name", self.name)
cp = self.__class__(type=self.type, name=name, **kwargs)
cp.tag = copy(self.tag)
return cp
class NominalVariable(AtomicVariable[_TypeType]):
"""A variable that enables alpha-equivalent comparisons."""
......@@ -682,12 +694,13 @@ class NominalVariable(AtomicVariable[_TypeType]):
return cls.__instances__[(typ, id)]
def __init__(self, id: _IdType, typ: _TypeType, **kwargs):
def __init__(self, id: _IdType, typ: _TypeType, name: Optional[str] = None):
self.id = id
super().__init__(typ, **kwargs)
super().__init__(type=typ, name=name)
def clone(self):
return self
def clone(self, **kwargs):
name = kwargs.pop("name", self.name)
return self.__class__(id=self.id, typ=self.type, name=name, **kwargs)
def __eq__(self, other):
if self is other:
......@@ -744,8 +757,7 @@ class Constant(AtomicVariable[_TypeType]):
name = name[:10] + "..." + name[-10:]
return f"{type(self).__name__}{{{name}}}"
def clone(self):
"""Return `self`, because there's no reason to clone a constant."""
def clone(self, **kwargs):
return self
@property
......
......@@ -354,6 +354,12 @@ class TestAutoName:
assert r1.auto_name == "auto_" + str(autoname_id)
assert r2.auto_name == "auto_" + str(autoname_id + 1)
assert r1.name is None and r1.name is r2.name
r3_name = "r3"
r3 = r1.clone(name=r3_name)
assert r3.name == r3_name
def test_equal_computations():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论