提交 4c2d3478 authored 作者: Frederic's avatar Frederic

Add assert to detect error earlier.

上级 21a5a3c7
...@@ -963,6 +963,7 @@ class Eigh(Eig): ...@@ -963,6 +963,7 @@ class Eigh(Eig):
_numop = staticmethod(numpy.linalg.eigh) _numop = staticmethod(numpy.linalg.eigh)
def __init__(self, UPLO='L'): def __init__(self, UPLO='L'):
assert UPLO in ['L', 'U']
self.UPLO = UPLO self.UPLO = UPLO
def __str__(self): def __str__(self):
...@@ -1031,6 +1032,7 @@ class EighGrad(Op): ...@@ -1031,6 +1032,7 @@ class EighGrad(Op):
""" """
def __init__(self, UPLO='L'): def __init__(self, UPLO='L'):
assert UPLO in ['L', 'U']
self.UPLO = UPLO self.UPLO = UPLO
if UPLO == 'L': if UPLO == 'L':
self.tri0 = numpy.tril self.tri0 = numpy.tril
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论