提交 05204a8c authored 作者: Shawn Tan's avatar Shawn Tan

Implemented set state

上级 076056db
......@@ -6544,7 +6544,6 @@ class AllocDiag(Op):
__props__ = ("offset", "axis1", "axis2")
def __init__(self, offset=0, axis1=0, axis2=1):
self.view_map = {0: [0]}
self.offset = offset
self.axis1 = axis1
self.axis2 = axis2
......@@ -6580,6 +6579,7 @@ class AllocDiag(Op):
# Fill in final 2 axes with x
result[diagonal_slice] = x
if len(x.shape) > 1:
# Re-order axes so they correspond to diagonals at axis1, axis2
axes = list(range(len(x.shape[:-1])))
......@@ -6610,6 +6610,20 @@ class AllocDiag(Op):
result_shape = result_shape[:axis2] + [diag_shape] + result_shape[axis2:]
return [tuple(result_shape)]
def __setstate__(self, state):
if "view_map" in state:
del state["view_map"]
self.__dict__.update(state)
if "offset" not in state:
self.offset = 0
if "axis1" not in state:
self.axis1 = 0
if "axis2" not in state:
self.axis2 = 1
def diag(v, k=0):
"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论