提交 5b5b21a6 authored 作者: Frederic Bastien's avatar Frederic Bastien

check that the input of TensorDot.__init__ and TensorDotGrad.__init__ are valid.

上级 b16882ca
...@@ -3766,6 +3766,13 @@ class TensorDotGrad(Op): ...@@ -3766,6 +3766,13 @@ class TensorDotGrad(Op):
if isinstance(a,list): if isinstance(a,list):
axes[i]=tuple(a) axes[i]=tuple(a)
axes=tuple(axes) axes=tuple(axes)
if isinstance(axes, tuple):
if len(axes)!=2:
raise ValueError("We need the list/tuple of axes to be of length 2")
if len(axes[0])!=len(axes[1]):
raise ValueError("We need that the axes 2 sub list of axes are of the same size")
assert len(axes[0])==len(axes[1])
self.axes = axes; self.axes = axes;
def __eq__(self, other): def __eq__(self, other):
...@@ -3818,7 +3825,14 @@ class TensorDot(Op): ...@@ -3818,7 +3825,14 @@ class TensorDot(Op):
if isinstance(a,list): if isinstance(a,list):
axes[i]=tuple(a) axes[i]=tuple(a)
axes=tuple(axes) axes=tuple(axes)
self.axes = axes; if isinstance(axes, tuple):
if len(axes)!=2:
raise ValueError("We need the list/tuple of axes to be of length 2")
if len(axes[0])!=len(axes[1]):
raise ValueError("We need that the axes 2 sub list of axes are of the same size")
assert len(axes[0])==len(axes[1])
self.axes = axes
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) and self.axes == other.axes return type(self) == type(other) and self.axes == other.axes
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论