提交 196346a0 authored 作者: notoraptor's avatar notoraptor

Also check red_op.

上级 df945124
...@@ -1626,8 +1626,9 @@ def test_dnn_reduction_axis_size_one(): ...@@ -1626,8 +1626,9 @@ def test_dnn_reduction_axis_size_one():
f2 = theano.function([x], sum_squares, mode=mode_with_gpu) f2 = theano.function([x], sum_squares, mode=mode_with_gpu)
f3 = theano.function([x], sum_abs, mode=mode_with_gpu) f3 = theano.function([x], sum_abs, mode=mode_with_gpu)
f4 = theano.function([x], absmax, mode=mode_with_gpu) f4 = theano.function([x], absmax, mode=mode_with_gpu)
for fn in (f1, f2, f3, f4): for fn, red_op in ((f1, 'add'), (f2, 'norm2'), (f3, 'norm1'), (f4, 'absmax')):
assert any(isinstance(node.op, dnn.GpuDnnReduction) for node in fn.maker.fgraph.apply_nodes) assert any(isinstance(node.op, dnn.GpuDnnReduction) and node.op.red_op == red_op
for node in fn.maker.fgraph.apply_nodes)
xval = np.random.uniform(-10, -1, size=shape).astype(dtype) xval = np.random.uniform(-10, -1, size=shape).astype(dtype)
xval_reshaped = xval.reshape(shape[:axis] + shape[(axis + 1):]) xval_reshaped = xval.reshape(shape[:axis] + shape[(axis + 1):])
test_val = abs(xval_reshaped) test_val = abs(xval_reshaped)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论