提交 ecdb5d3f authored 作者: Jeremiah Lowin's avatar Jeremiah Lowin

ensure inputs are tensor variables

上级 c34bdacb
...@@ -7089,6 +7089,8 @@ def dot(a, b): ...@@ -7089,6 +7089,8 @@ def dot(a, b):
with a fast BLAS. with a fast BLAS.
""" """
a, b = as_tensor_variable(a), as_tensor_variable(b)
if a.ndim == 0 or b.ndim == 0: if a.ndim == 0 or b.ndim == 0:
return a * b return a * b
elif a.ndim > 2 or b.ndim > 2: elif a.ndim > 2 or b.ndim > 2:
...@@ -7188,7 +7190,7 @@ def tensordot(a, b, axes = 2): ...@@ -7188,7 +7190,7 @@ def tensordot(a, b, axes = 2):
See the documentation of np.tensordot for more examples. See the documentation of np.tensordot for more examples.
""" """
a, b = map(as_tensor_variable, (a, b)) a, b = as_tensor_variable(a), as_tensor_variable(b)
# axes must be a scalar or list/tuple of length 2 # axes must be a scalar or list/tuple of length 2
if not numpy.isscalar(axes) and len(axes) != 2: if not numpy.isscalar(axes) and len(axes) != 2:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论