提交 10901b6a authored 作者: carriepl's avatar carriepl

Merge pull request #3035 from thomasmesnard/Inrangegrad

Add missing test.
...@@ -21,7 +21,7 @@ from theano.scalar.basic import (floats, float32, float64, ...@@ -21,7 +21,7 @@ from theano.scalar.basic import (floats, float32, float64,
ints, int8, int32, complex64, ints, int8, int32, complex64,
ComplexError, IntDiv, TrueDiv, ComplexError, IntDiv, TrueDiv,
Composite, add, div_proxy, Composite, add, div_proxy,
and_, eq, neq, invert, mul, Scalar) and_, eq, neq, invert, mul, Scalar, InRange)
from theano.scalar.basic import ( from theano.scalar.basic import (
true_div, inv, log, log2, log10, log1p, exp, exp2, expm1, sqrt, deg2rad, true_div, inv, log, log2, log10, log1p, exp, exp2, expm1, sqrt, deg2rad,
rad2deg, cos, arccos, sin, arcsin, tan, arctan, arctan2, cosh, arccosh, rad2deg, cos, arccos, sin, arcsin, tan, arctan, arctan2, cosh, arccosh,
...@@ -413,6 +413,32 @@ def test_grad_identity(): ...@@ -413,6 +413,32 @@ def test_grad_identity():
theano.gradient.grad(l, x) theano.gradient.grad(l, x)
def test_grad_inrange():
for bound_definition in [(True, True), (False, False)]:
# Instantiate op, and then take the gradient
op = InRange(*bound_definition)
x = theano.tensor.fscalar('x')
low = theano.tensor.fscalar('low')
high = theano.tensor.fscalar('high')
out = op(x, low, high)
gx, glow, ghigh = theano.tensor.grad(out, [x, low, high])
# We look if the gradient are equal to zero
# if x is lower than the lower bound,
# equal to the lower bound, between lower and higher bound,
# equal to the higher bound and higher than the higher
# bound.
# Mathematically we should have an infinite gradient when
# x is equal to the lower or higher bound but in that case
# Theano defines the gradient to be zero for stability.
f = theano.function([x, low, high], [gx, glow, ghigh])
utt.assert_allclose(f(0, 1, 5), [0, 0, 0])
utt.assert_allclose(f(1, 1, 5), [0, 0, 0])
utt.assert_allclose(f(2, 1, 5), [0, 0, 0])
utt.assert_allclose(f(5, 1, 5), [0, 0, 0])
utt.assert_allclose(f(7, 1, 5), [0, 0, 0])
# Testing of Composite is done in tensor/tests/test_opt.py # Testing of Composite is done in tensor/tests/test_opt.py
# in test_fusion, TestCompositeCodegen # in test_fusion, TestCompositeCodegen
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论