提交 7d1c9917 authored 作者: Tim Cooijmans's avatar Tim Cooijmans

BatchedDot: provide dtype in numpy reference implementation

上级 ed4e0679
......@@ -1934,8 +1934,9 @@ DotTester = makeTester(name='DotTester',
BatchedDotTester = makeTester(
name='BatchedDotTester',
op=batched_dot,
expected=lambda xs, ys: numpy.asarray(list(
numpy.dot(x, y) for x, y in zip(xs, ys))),
expected=lambda xs, ys: numpy.asarray(
list(numpy.dot(x, y) for x, y in zip(xs, ys)),
dtype=theano.scalar.upcast(xs.dtype, ys.dtype)),
checks={},
grad=dict(correct1=(rand(3, 5, 7), rand(3, 7, 5)),
correct2=(rand(3, 5, 7), rand(3, 7, 9)),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论