提交 a7f6171c authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #2225 from caglar/fix_grad_clip

Fixes #633
...@@ -1792,7 +1792,7 @@ class Clip(ScalarOp): ...@@ -1792,7 +1792,7 @@ class Clip(ScalarOp):
def grad(self, (x, mn, mx), (gz, )): def grad(self, (x, mn, mx), (gz, )):
assert gz.type not in complex_types assert gz.type not in complex_types
gx = ((x > mn) & (x < mx)) * gz gx = ((x >= mn) & (x <= mx)) * gz
gmn = (x < mn) * gz gmn = (x < mn) * gz
gmx = (x > mx) * gz gmx = (x > mx) * gz
......
...@@ -14,13 +14,14 @@ import unittest ...@@ -14,13 +14,14 @@ import unittest
import theano import theano
from theano.gof import FunctionGraph from theano.gof import FunctionGraph
from theano import gof from theano import gof
from theano.tests import unittest_tools as utt
from theano.scalar.basic import (floats, float32, float64, from theano.scalar.basic import (floats, float32, float64,
ints, int8, int32, complex64, ints, int8, int32, complex64,
ComplexError, IntDiv, TrueDiv, ComplexError, IntDiv, TrueDiv,
Composite, add, div_proxy, Composite, add, div_proxy, clip,
and_, eq, neq, invert, mul) and_, eq, neq, invert, mul)
import numpy
def inputs(): def inputs():
return floats('xyz') return floats('xyz')
...@@ -56,6 +57,43 @@ class test_ScalarOps(unittest.TestCase): ...@@ -56,6 +57,43 @@ class test_ScalarOps(unittest.TestCase):
): ):
self.assertTrue(fn(a,b) == a%b, (a,)) self.assertTrue(fn(a,b) == a%b, (a,))
def test_clip_grad(self):
#This is testing for the issue #633
x, y = floats('xy')
a = theano.tensor.clip(x, y, x)
g = theano.gradient.grad(a, x)
fn = gof.DualLinker().accept(FunctionGraph([x, y], [g])).make_function()
# Test the other way around as well
a2 = theano.tensor.clip(x, x, y)
g2 = theano.gradient.grad(a2, x)
fn2 = gof.DualLinker().accept(FunctionGraph([x, y], [g2])).make_function()
# Test for the equal case too .
a3 = theano.tensor.clip(x, x, x)
g3 = theano.gradient.grad(a3, x)
fn3 = gof.DualLinker().accept(FunctionGraph([x], [g3])).make_function()
rng = numpy.random.RandomState(utt.fetch_seed())
ntests = 50
for i in xrange(ntests):
xval = rng.rand(1)
#To ensure that the min < x .
yval_mn = rng.rand(1) - 1.0
#To ensure that the max > x.
yval_mx = rng.rand(1) + 1.0
aval = fn(xval, yval_mn)
aval2 = fn2(xval, yval_mx)
aval3 = fn3(xval)
self.assertTrue(aval == 1.)
self.assertTrue(aval2 == 1.)
self.assertTrue(aval3 == 1.)
class test_composite(unittest.TestCase): class test_composite(unittest.TestCase):
def test_straightforward(self): def test_straightforward(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论