提交 0699b48d authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Trust input in test_math_scipy benchmark tests

上级 e299023b
...@@ -431,11 +431,13 @@ def test_gammaincc_ddk_performance(benchmark): ...@@ -431,11 +431,13 @@ def test_gammaincc_ddk_performance(benchmark):
x = vector("x") x = vector("x")
out = gammaincc(k, x) out = gammaincc(k, x)
grad_fn = function([k, x], grad(out.sum(), wrt=[k]), mode="FAST_RUN") grad_fn = function(
[k, x], grad(out.sum(), wrt=[k]), mode="FAST_RUN", trust_input=True
)
vals = [ vals = [
# Values that hit the second branch of the gradient # Values that hit the second branch of the gradient
np.full((1000,), 3.2), np.full((1000,), 3.2, dtype=k.dtype),
np.full((1000,), 0.01), np.full((1000,), 0.01, dtype=x.dtype),
] ]
verify_grad(gammaincc, vals, rng=rng) verify_grad(gammaincc, vals, rng=rng)
...@@ -1127,9 +1129,13 @@ class TestHyp2F1Grad: ...@@ -1127,9 +1129,13 @@ class TestHyp2F1Grad:
a1, a2, b1, z = pt.scalars("a1", "a2", "b1", "z") a1, a2, b1, z = pt.scalars("a1", "a2", "b1", "z")
hyp2f1_out = pt.hyp2f1(a1, a2, b1, z) hyp2f1_out = pt.hyp2f1(a1, a2, b1, z)
hyp2f1_grad = pt.grad(hyp2f1_out, wrt=a1 if wrt == "a" else [a1, a2, b1, z]) hyp2f1_grad = pt.grad(hyp2f1_out, wrt=a1 if wrt == "a" else [a1, a2, b1, z])
f_grad = function([a1, a2, b1, z], hyp2f1_grad) f_grad = function([a1, a2, b1, z], hyp2f1_grad, trust_input=True)
(test_a1, test_a2, test_b1, test_z, *expected_dds) = case (test_a1, test_a2, test_b1, test_z, *expected_dds) = case
test_a1 = np.array(test_a1, dtype=a1.dtype)
test_a2 = np.array(test_a2, dtype=a2.dtype)
test_b1 = np.array(test_b1, dtype=b1.dtype)
test_z = np.array(test_z, dtype=z.dtype)
result = benchmark(f_grad, test_a1, test_a2, test_b1, test_z) result = benchmark(f_grad, test_a1, test_a2, test_b1, test_z)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论