提交 7d1c9917 authored 作者: Tim Cooijmans's avatar Tim Cooijmans

BatchedDot: provide dtype in numpy reference implementation

上级 ed4e0679
...@@ -1934,8 +1934,9 @@ DotTester = makeTester(name='DotTester', ...@@ -1934,8 +1934,9 @@ DotTester = makeTester(name='DotTester',
BatchedDotTester = makeTester( BatchedDotTester = makeTester(
name='BatchedDotTester', name='BatchedDotTester',
op=batched_dot, op=batched_dot,
expected=lambda xs, ys: numpy.asarray(list( expected=lambda xs, ys: numpy.asarray(
numpy.dot(x, y) for x, y in zip(xs, ys))), list(numpy.dot(x, y) for x, y in zip(xs, ys)),
dtype=theano.scalar.upcast(xs.dtype, ys.dtype)),
checks={}, checks={},
grad=dict(correct1=(rand(3, 5, 7), rand(3, 7, 5)), grad=dict(correct1=(rand(3, 5, 7), rand(3, 7, 5)),
correct2=(rand(3, 5, 7), rand(3, 7, 9)), correct2=(rand(3, 5, 7), rand(3, 7, 9)),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论