提交 5aa17801 authored 作者: Will Dean's avatar Will Dean 提交者: Rémi Louf

Add name argument to clone methods via kwargs

上级 ac5fb058
...@@ -167,9 +167,10 @@ class SharedVariable(Variable): ...@@ -167,9 +167,10 @@ class SharedVariable(Variable):
else: else:
self.container.value = 0 * self.container.value self.container.value = 0 * self.container.value
def clone(self): def clone(self, **kwargs):
name = kwargs.get("name", self.name)
cp = self.__class__( cp = self.__class__(
name=self.name, name=name,
type=self.type, type=self.type,
value=None, value=None,
strict=None, strict=None,
......
...@@ -510,9 +510,14 @@ class Variable(Node, Generic[_TypeType, OptionalApplyType]): ...@@ -510,9 +510,14 @@ class Variable(Node, Generic[_TypeType, OptionalApplyType]):
pass pass
return "\n".join(to_print) return "\n".join(to_print)
def clone(self): def clone(self, **kwargs):
"""Return a new, un-owned `Variable` like `self`. """Return a new, un-owned `Variable` like `self`.
Parameters
----------
**kwargs : dict
Optional "name" keyword argument for the copied instance. Same as `self.name` if value not provided.
Returns Returns
------- -------
Variable instance Variable instance
...@@ -523,7 +528,8 @@ class Variable(Node, Generic[_TypeType, OptionalApplyType]): ...@@ -523,7 +528,8 @@ class Variable(Node, Generic[_TypeType, OptionalApplyType]):
Tags and names are copied to the returned instance. Tags and names are copied to the returned instance.
""" """
cp = self.__class__(self.type, None, None, self.name) name = kwargs.pop("name", self.name)
cp = self.__class__(type=self.type, owner=None, index=None, name=name, **kwargs)
cp.tag = copy(self.tag) cp.tag = copy(self.tag)
return cp return cp
...@@ -621,8 +627,8 @@ class Variable(Node, Generic[_TypeType, OptionalApplyType]): ...@@ -621,8 +627,8 @@ class Variable(Node, Generic[_TypeType, OptionalApplyType]):
class AtomicVariable(Variable[_TypeType, None]): class AtomicVariable(Variable[_TypeType, None]):
"""A node type that has no ancestors and should never be considered an input to a graph.""" """A node type that has no ancestors and should never be considered an input to a graph."""
def __init__(self, type: _TypeType, **kwargs): def __init__(self, type: _TypeType, name: Optional[str] = None, **kwargs):
super().__init__(type, None, None, **kwargs) super().__init__(type=type, owner=None, index=None, name=name, **kwargs)
@abc.abstractmethod @abc.abstractmethod
def signature(self): def signature(self):
...@@ -656,6 +662,12 @@ class AtomicVariable(Variable[_TypeType, None]): ...@@ -656,6 +662,12 @@ class AtomicVariable(Variable[_TypeType, None]):
if value is not None: if value is not None:
raise ValueError("AtomicVariable instances cannot have an index.") raise ValueError("AtomicVariable instances cannot have an index.")
def clone(self, **kwargs):
name = kwargs.pop("name", self.name)
cp = self.__class__(type=self.type, name=name, **kwargs)
cp.tag = copy(self.tag)
return cp
class NominalVariable(AtomicVariable[_TypeType]): class NominalVariable(AtomicVariable[_TypeType]):
"""A variable that enables alpha-equivalent comparisons.""" """A variable that enables alpha-equivalent comparisons."""
...@@ -682,12 +694,13 @@ class NominalVariable(AtomicVariable[_TypeType]): ...@@ -682,12 +694,13 @@ class NominalVariable(AtomicVariable[_TypeType]):
return cls.__instances__[(typ, id)] return cls.__instances__[(typ, id)]
def __init__(self, id: _IdType, typ: _TypeType, **kwargs): def __init__(self, id: _IdType, typ: _TypeType, name: Optional[str] = None):
self.id = id self.id = id
super().__init__(typ, **kwargs) super().__init__(type=typ, name=name)
def clone(self): def clone(self, **kwargs):
return self name = kwargs.pop("name", self.name)
return self.__class__(id=self.id, typ=self.type, name=name, **kwargs)
def __eq__(self, other): def __eq__(self, other):
if self is other: if self is other:
...@@ -744,8 +757,7 @@ class Constant(AtomicVariable[_TypeType]): ...@@ -744,8 +757,7 @@ class Constant(AtomicVariable[_TypeType]):
name = name[:10] + "..." + name[-10:] name = name[:10] + "..." + name[-10:]
return f"{type(self).__name__}{{{name}}}" return f"{type(self).__name__}{{{name}}}"
def clone(self): def clone(self, **kwargs):
"""Return `self`, because there's no reason to clone a constant."""
return self return self
@property @property
......
...@@ -354,6 +354,12 @@ class TestAutoName: ...@@ -354,6 +354,12 @@ class TestAutoName:
assert r1.auto_name == "auto_" + str(autoname_id) assert r1.auto_name == "auto_" + str(autoname_id)
assert r2.auto_name == "auto_" + str(autoname_id + 1) assert r2.auto_name == "auto_" + str(autoname_id + 1)
assert r1.name is None and r1.name is r2.name
r3_name = "r3"
r3 = r1.clone(name=r3_name)
assert r3.name == r3_name
def test_equal_computations(): def test_equal_computations():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论