提交 b5cca42e authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #2408 from ChienliMa/grad_for_round

Grad for round
......@@ -2058,6 +2058,14 @@ class RoundHalfToEven(UnaryScalarOp):
def impl(self, x):
return numpy.round(x)
def grad(self, (x,), (gz,)):
rval = x.zeros_like()
if rval.type.dtype in discrete_types:
rval = rval.astype(theano.config.floatX)
return [rval]
def c_code___(self, node, name, (x, ), (z, ), sub):
typ = node.outputs[0].type.dtype
if not typ in ['float32', 'float64']:
......@@ -2140,9 +2148,16 @@ class RoundHalfAwayFromZero(UnaryScalarOp):
See http://en.wikipedia.org/wiki/Rounding for more detail
"""
def impl(self, x):
return round_half_away_from_zero_vec(x)
def grad(self, (x,), (gz,)):
rval = x.zeros_like()
if rval.type.dtype in discrete_types:
rval = rval.astype(theano.config.floatX)
return [rval]
def c_code(self, node, name, (x, ), (z, ), sub):
if node.outputs[0].type.dtype in ['float32', 'float64']:
return "%(z)s = round(%(x)s);" % locals()
......
......@@ -1099,6 +1099,10 @@ _grad_broadcast_unary_normal = dict(
#empty = [numpy.asarray([])] # XXX: should this be included?
)
_grad_broadcast_unary_normal_no_complex_no_corner_case = copymod(
_grad_broadcast_unary_normal_no_complex,
without=['corner_case'])
_grad_broadcast_unary_abs1_no_complex = dict(
normal=[numpy.asarray(rand_ranged(-1, 1, (2, 3)), dtype=floatX)],
)
......@@ -1216,13 +1220,15 @@ TruncTester = makeBroadcastTester(
RoundHalfToEvenTester = makeBroadcastTester(
op=tensor.round_half_to_even,
expected=numpy.round,
good=_good_broadcast_unary_normal_float_no_complex)
# TODO: Why complex are accepted in the next one?
expected= numpy.round,
good=_good_broadcast_unary_normal_float_no_complex,
grad=_grad_broadcast_unary_normal_no_complex_no_corner_case)
RoundHalfToEvenInplaceTester = makeBroadcastTester(
op=inplace.round_half_to_even_inplace,
expected=numpy.round,
good=_good_broadcast_unary_normal_float,
expected= numpy.round,
good=_good_broadcast_unary_normal_float_no_complex,
grad=_grad_broadcast_unary_normal_no_complex_no_corner_case,
inplace=True)
#numpy.vectorize don't handle correctly empty ndarray.
......@@ -1230,19 +1236,23 @@ RoundHalfToEvenInplaceTester = makeBroadcastTester(
#This happen in float32 mode.
RoundHalfAwayFromZeroTester = makeBroadcastTester(
op=tensor.round_half_away_from_zero,
expected=theano.scalar.basic.round_half_away_from_zero_vec,
good=_good_broadcast_unary_normal_float_no_empty_no_complex)
expected=lambda a:theano.scalar.basic.round_half_away_from_zero_vec(a),
good=_good_broadcast_unary_normal_float_no_empty_no_complex,
grad=_grad_broadcast_unary_normal_no_complex_no_corner_case)
#_good_broadcast_unary_normal_float)
RoundHalfAwayFromZeroInplaceTester = makeBroadcastTester(
op=inplace.round_half_away_from_zero_inplace,
expected=theano.scalar.basic.round_half_away_from_zero_vec,
expected=lambda a:theano.scalar.basic.round_half_away_from_zero_vec(a),
good=_good_broadcast_unary_normal_float_no_empty_no_complex,
grad=_grad_broadcast_unary_normal_no_complex_no_corner_case,
inplace=True)
SqrTester = makeBroadcastTester(op=tensor.sqr,
expected=numpy.square,
good=_good_broadcast_unary_normal,
grad=_grad_broadcast_unary_normal)
SqrInplaceTester = makeBroadcastTester(op=inplace.sqr_inplace,
expected=numpy.square,
good=_good_broadcast_unary_normal,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论