提交 da66c2ef 作者: Ricardo Vieira 提交者: Thomas Wiecki

Fix SoftmaxGrad failure with constant dy in numba backend

上级 fc5e10f7
......@@ -925,7 +925,12 @@ def numba_funcify_SoftmaxGrad(op, node, **kwargs):
dx = dy_times_sm - sum_dy_times_sm * sm
return dx
softmax_grad = jit_compile_reducer(node, softmax_grad_py_fn)
# The signature inferred by jit_compile_reducer is wrong when dy is a constant (readonly=True)
# softmax_grad = jit_compile_reducer(node, softmax_grad_py_fn)
softmax_grad = numba_njit(
boundscheck=False,
fastmath=config.numba__fastmath,
)(softmax_grad_py_fn)
return softmax_grad
......
......@@ -445,6 +445,16 @@ def test_SoftmaxGrad(dy, sm, axis, exc):
)
def test_SoftMaxGrad_constant_dy():
    """Regression test: SoftmaxGrad must compile under the numba backend
    even when ``dy`` is a constant (readonly) tensor — see the commit note
    about the signature inferred by ``jit_compile_reducer``."""
    constant_dy = at.constant(np.zeros((3,), dtype=config.floatX))
    softmax_out = at.vector(shape=(3,))
    grad = SoftmaxGrad(axis=None)(constant_dy, softmax_out)
    fgraph = FunctionGraph(outputs=[grad])
    compare_numba_and_py(fgraph, [np.ones((3,), dtype=config.floatX)])
@pytest.mark.parametrize(
"x, axis, exc",
[
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论