提交 2df16b66 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

fixed grad for Minimum, Sub, Trunc, Sum, and ConvTransp3D

上级 e87e11cf
...@@ -1124,11 +1124,13 @@ class Minimum(BinaryScalarOp): ...@@ -1124,11 +1124,13 @@ class Minimum(BinaryScalarOp):
def grad(self, (x, y), (gz, )): def grad(self, (x, y), (gz, )):
assert gz.type not in complex_types assert gz.type not in complex_types
# max is not defined for complex_types # max is not defined for complex_types
gx, gy = None, None
if x.type in float_types: output = minimum(x,y)
gx = cast(eq(minimum(x, y), x) * gz, x.type.dtype) if output.type in discrete_types:
if y.type in float_types: return [x.zeros_like().astype(theano.config.floatX),
gy = cast(eq(minimum(x, y), y) * gz, y.type.dtype) y.zeros_like().astype(theano.config.floatX)]
gx = eq(output, x) * gz
gy = eq(output, y) * gz
return (gx, gy) return (gx, gy)
minimum = Minimum(upcast_out, name='minimum') minimum = Minimum(upcast_out, name='minimum')
...@@ -1225,15 +1227,13 @@ class Sub(BinaryScalarOp): ...@@ -1225,15 +1227,13 @@ class Sub(BinaryScalarOp):
if gz.type in complex_types: if gz.type in complex_types:
raise NotImplementedError() raise NotImplementedError()
if x.type in float_types: if (x-y).type in discrete_types:
first_part = cast(gz, x.type.dtype) return [x.zeros_like().astype(theano.config.floatX),
else: y.zeros_like().astype(theano.config.floatX)]
first_part = None
first_part = gz
second_part = -gz
if y.type in float_types:
second_part = cast(-gz, y.type.dtype)
else:
second_part = None
return first_part, second_part return first_part, second_part
sub = Sub(upcast_out, name='sub') sub = Sub(upcast_out, name='sub')
...@@ -1767,7 +1767,7 @@ class Trunc(UnaryScalarOp): ...@@ -1767,7 +1767,7 @@ class Trunc(UnaryScalarOp):
return numpy.trunc(x) return numpy.trunc(x)
def grad(self, (x,), (gz,)): def grad(self, (x,), (gz,)):
return None, return x.zeros_like().astype(theano.config.floatX)
def c_code(self, node, name, (x,), (z,), sub): def c_code(self, node, name, (x,), (z,), sub):
return "%(z)s = %(x)s >= 0? floor(%(x)s): -floor(-%(x)s);" % locals() return "%(z)s = %(x)s >= 0? floor(%(x)s): -floor(-%(x)s);" % locals()
......
...@@ -1612,7 +1612,7 @@ class Sum(CAReduceDtype): ...@@ -1612,7 +1612,7 @@ class Sum(CAReduceDtype):
new_dims.append(i) new_dims.append(i)
i += 1 i += 1
ds_op = DimShuffle(gz.type.broadcastable, new_dims) ds_op = DimShuffle(gz.type.broadcastable, new_dims)
gx = Elemwise(scalar.second)(x, ds_op(gz).astype(x.dtype)) gx = Elemwise(scalar.second)(x, ds_op(gz))
return [gx] return [gx]
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
......
...@@ -2,6 +2,8 @@ import numpy as N ...@@ -2,6 +2,8 @@ import numpy as N
from theano.tensor import basic as T from theano.tensor import basic as T
from theano.misc import strutil from theano.misc import strutil
import theano import theano
from theano.gradient import grad_undefined
from theano.gradient import DisconnectedType
class ConvTransp3D(theano.Op): class ConvTransp3D(theano.Op):
""" "Transpose" of Conv3D (Conv3D implements multiplication by an implicitly defined matrix W. This implements multiplication by its transpose) """ """ "Transpose" of Conv3D (Conv3D implements multiplication by an implicitly defined matrix W. This implements multiplication by its transpose) """
...@@ -42,6 +44,9 @@ class ConvTransp3D(theano.Op): ...@@ -42,6 +44,9 @@ class ConvTransp3D(theano.Op):
W_shape, b_shape, d_shape, H_shape, RShape_shape = input_shapes W_shape, b_shape, d_shape, H_shape, RShape_shape = input_shapes
return [(H_shape[0], RShape[0], RShape[1], RShape[2], W_shape[4])] return [(H_shape[0], RShape[0], RShape[1], RShape[2], W_shape[4])]
def connection_pattern(self, node):
return [[True], [True], [True], [True], [False]]
def grad(self,inputs, output_gradients): def grad(self,inputs, output_gradients):
W,b,d,H, RShape = inputs W,b,d,H, RShape = inputs
dCdR ,= output_gradients dCdR ,= output_gradients
...@@ -49,8 +54,10 @@ class ConvTransp3D(theano.Op): ...@@ -49,8 +54,10 @@ class ConvTransp3D(theano.Op):
WShape = W.shape WShape = W.shape
dCdW = convGrad3D(dCdR,d,WShape,H) dCdW = convGrad3D(dCdR,d,WShape,H)
dCdb = T.sum(dCdR,axis=(0,1,2,3)) dCdb = T.sum(dCdR,axis=(0,1,2,3))
dCdd = None #not differentiable, since d is not continuous # not differentiable, since d affects the output elements
dCdRShape = None #not differentiable, since RShape is not continuous dCdd = grad_undefined(self,2,d)
# disconnected, since RShape just determines the output shape
dCdRShape = DisconnectedType()()
if 'name' in dir(dCdR) and dCdR.name is not None: if 'name' in dir(dCdR) and dCdR.name is not None:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论