提交 9ad68dfa authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Make Clip return a Numba scalar

上级 6570860a
...@@ -1183,9 +1183,15 @@ def numba_funcify_Clip(op, **kwargs): ...@@ -1183,9 +1183,15 @@ def numba_funcify_Clip(op, **kwargs):
@numba.njit @numba.njit
def clip(_x, _min, _max): def clip(_x, _min, _max):
x = to_scalar(_x) x = to_scalar(_x)
min = to_scalar(_min) _min_scalar = to_scalar(_min)
max = to_scalar(_max) _max_scalar = to_scalar(_max)
return np.where(x < min, min, to_scalar(np.where(x > max, max, x)))
if x < _min_scalar:
return _min_scalar
elif x > _max_scalar:
return _max_scalar
else:
return x
return clip return clip
......
...@@ -917,6 +917,17 @@ def test_Clip(v, min, max): ...@@ -917,6 +917,17 @@ def test_Clip(v, min, max):
) )
def test_scalar_Elemwise_Clip():
a = aet.scalar("a")
b = aet.scalar("b")
z = aet.switch(1, a, b)
c = aet.clip(z, 1, 3)
c_fg = FunctionGraph(outputs=[c])
compare_numba_and_py(c_fg, [1, 1])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"vals, dtype", "vals, dtype",
[ [
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论