提交 bee1a7a0 authored 作者: Tim Cooijmans's avatar Tim Cooijmans

BatchedDot: catch bad inputs

上级 91fe0a2b
......@@ -3563,7 +3563,11 @@ def batched_dot(a, b):
"""
a, b = as_tensor_variable(a), as_tensor_variable(b)
if a.ndim == 1:
if a.ndim == 0:
raise TypeError("a must have at least one (batch) axis")
elif b.ndim == 0:
raise TypeError("b must have at least one (batch) axis")
elif a.ndim == 1:
return a.dimshuffle(*([0] + ["x"] * (b.ndim - 1))) * b
elif b.ndim == 1:
return a * b.dimshuffle(*([0] + ["x"] * (a.ndim - 1)))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论