提交 7f49bd67 authored 作者: Jeremiah Lowin's avatar Jeremiah Lowin

check for non-variable inputs to tensordot

上级 329c19e6
...@@ -7232,6 +7232,8 @@ def tensordot(a, b, axes = 2): ...@@ -7232,6 +7232,8 @@ def tensordot(a, b, axes = 2):
See the documentation of np.tensordot for more examples. See the documentation of np.tensordot for more examples.
""" """
a, b = map(as_tensor_variable, (a, b))
# axes must be a scalar or list/tuple of length 2 # axes must be a scalar or list/tuple of length 2
if not numpy.isscalar(axes) and len(axes) != 2: if not numpy.isscalar(axes) and len(axes) != 2:
raise ValueError('Axes should be scalar valued or a ' raise ValueError('Axes should be scalar valued or a '
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论