提交 d38dc060 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Add QOL XRV methods

上级 60f78634
...@@ -198,6 +198,8 @@ class XRV(XOp, RNGConsumerOp): ...@@ -198,6 +198,8 @@ class XRV(XOp, RNGConsumerOp):
if len(set(extra_dims)) != len(extra_dims): if len(set(extra_dims)) != len(extra_dims):
raise ValueError("size_dims must be unique") raise ValueError("size_dims must be unique")
self.extra_dims = tuple(extra_dims) self.extra_dims = tuple(extra_dims)
if print_name := getattr(core_op, "_print_name", None):
self._print_name = print_name
def __str__(self): def __str__(self):
if self.name is not None: if self.name is not None:
...@@ -208,6 +210,15 @@ class XRV(XOp, RNGConsumerOp): ...@@ -208,6 +210,15 @@ class XRV(XOp, RNGConsumerOp):
attrs = f"(core_op={self.core_op}, core_dims={self.core_dims}, extra_dims={self.extra_dims})" attrs = f"(core_op={self.core_op}, core_dims={self.core_dims}, extra_dims={self.extra_dims})"
return f"{name}({attrs})" return f"{name}({attrs})"
def rng_param(self, node):
return node.inputs[0]
def size_params(self, node):
return node.inputs[1 : 1 + len(self.extra_dims)]
def dist_params(self, node):
return node.inputs[1 + len(self.extra_dims) :]
def update(self, node): def update(self, node):
# RNG input and update are the first input and output respectively # RNG input and update are the first input and output respectively
return {node.inputs[0]: node.outputs[0]} return {node.inputs[0]: node.outputs[0]}
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论