提交 0fde9a49 authored 作者: Tim Cooijmans's avatar Tim Cooijmans

BatchedDot: simplify reference implementation

上级 1075d838
...@@ -1937,9 +1937,7 @@ BatchedDotTester = makeTester( ...@@ -1937,9 +1937,7 @@ BatchedDotTester = makeTester(
op=batched_dot, op=batched_dot,
expected=(lambda xs, ys: expected=(lambda xs, ys:
numpy.asarray( numpy.asarray(
list(x * y if x.ndim == 0 or y.ndim == 0 list(x * y if x.ndim == 0 or y.ndim == 0 else numpy.dot(x, y)
else (numpy.dot(x, y) if y.ndim == 1
else numpy.tensordot(x, y, [[x.ndim - 1], [y.ndim - 2]]))
for x, y in zip(xs, ys)), for x, y in zip(xs, ys)),
dtype=theano.scalar.upcast(xs.dtype, ys.dtype))), dtype=theano.scalar.upcast(xs.dtype, ys.dtype))),
checks={}, checks={},
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论