提交 47c3065c authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #1860 from daemonmaker/issue1780

Issue1780
...@@ -1712,6 +1712,7 @@ class Pow(BinaryScalarOp): ...@@ -1712,6 +1712,7 @@ class Pow(BinaryScalarOp):
first_part = gz * y * x ** (y - 1) first_part = gz * y * x ** (y - 1)
second_part = gz * log(x) * x ** y second_part = gz * log(x) * x ** y
second_part = switch(eq(x, 0), 0, second_part)
return (first_part, second_part) return (first_part, second_part)
......
...@@ -922,23 +922,36 @@ _grad_broadcast_pow_normal = dict(same_shapes = (rand_ranged(1, 5, (2, 3)), rand ...@@ -922,23 +922,36 @@ _grad_broadcast_pow_normal = dict(same_shapes = (rand_ranged(1, 5, (2, 3)), rand
#complex3 = (rand(2,3),randcomplex(2,3)), #complex3 = (rand(2,3),randcomplex(2,3)),
#empty1 = (numpy.asarray([]), numpy.asarray([1])), #empty1 = (numpy.asarray([]), numpy.asarray([1])),
#empty2 = (numpy.asarray([0]), numpy.asarray([])), #empty2 = (numpy.asarray([0]), numpy.asarray([])),
x_eq_zero = (
numpy.asarray([0.], dtype=config.floatX),
numpy.asarray([2.], dtype=config.floatX)
), # Test for issue 1780
) )
#empty2 case is not supported by numpy. #empty2 case is not supported by numpy.
_good_broadcast_pow_normal_float_pow = copy(_good_broadcast_pow_normal_float) _good_broadcast_pow_normal_float_pow = copy(_good_broadcast_pow_normal_float)
del _good_broadcast_pow_normal_float_pow["empty2"] del _good_broadcast_pow_normal_float_pow["empty2"]
# Disable NAN checking for pow operator per issue #1780
m = copy(theano.compile.get_default_mode())
m.check_isfinite = False
PowTester = makeBroadcastTester( PowTester = makeBroadcastTester(
op=pow, op=pow,
expected=lambda x, y: check_floatX((x, y), x ** y), expected=lambda x, y: check_floatX((x, y), x ** y),
good=_good_broadcast_pow_normal_float, good=_good_broadcast_pow_normal_float,
grad=_grad_broadcast_pow_normal, grad=_grad_broadcast_pow_normal,
name='Pow') name='Pow',
mode=m
)
PowInplaceTester = makeBroadcastTester(op=inplace.pow_inplace, PowInplaceTester = makeBroadcastTester(
op=inplace.pow_inplace,
expected=lambda x, y: x ** y, expected=lambda x, y: x ** y,
good = _good_broadcast_pow_normal_float_pow, good=_good_broadcast_pow_normal_float_pow,
grad = _grad_broadcast_pow_normal, grad=_grad_broadcast_pow_normal,
inplace = True) inplace=True,
mode=m
)
#Those are corner case when rounding. Their is many rounding algo. #Those are corner case when rounding. Their is many rounding algo.
#c round() fct and numpy round are not the same! #c round() fct and numpy round are not the same!
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论