提交 2f9672b3 authored 作者: James Bergstra's avatar James Bergstra

modified scalar ops to be nondifferentiable for integer arguments

上级 98603703
......@@ -183,6 +183,8 @@ int_types = int8, int16, int32, int64
float_types = float32, float64
complex_types = complex64, complex128
grad_types = float_types + complex_types # these are the types for which gradients can be defined.
class _scalar_py_operators:
#UNARY
......@@ -463,7 +465,9 @@ class Switch(ScalarOp):
def c_code(self, node, name, (cond, ift, iff), (z, ), sub):
return "%(z)s = %(cond)s ? %(ift)s : %(iff)s;" % locals()
def grad(self, (cond, ift, iff), (gz, )):
return None, switch(cond, gz, 0), switch(cond, 0, gz)
return (None,
switch(cond, gz, 0) if ift.type in grad_types else None,
switch(cond, 0, gz) if iff.type in grad_types else None)
def output_types(self, (cond_t, ift_t, iff_t)):
return upcast_out(ift_t, iff_t)
switch = Switch()
......@@ -539,7 +543,7 @@ class Add(ScalarOp):
else:
return z + " = " + " + ".join(inputs) + ";"
def grad(self, inputs, (gz, )):
return (gz, ) * len(inputs)
return [(gz if i.type in grad_types else None) for i in inputs]
add = Add(upcast_out, name = 'add')
class Mul(ScalarOp):
......@@ -554,7 +558,8 @@ class Mul(ScalarOp):
else:
return z + " = " + " * ".join(inputs) + ";"
def grad(self, inputs, (gz, )):
return [mul(*([gz] + utils.difference(inputs, [input])))
return [(mul(*([gz] + utils.difference(inputs, [input])))
if input.type in grad_types else None)
for input in inputs]
mul = Mul(upcast_out, name = 'mul')
......@@ -564,7 +569,7 @@ class Sub(BinaryScalarOp):
def c_code(self, node, name, (x, y), (z, ), sub):
return "%(z)s = %(x)s - %(y)s;" % locals()
def grad(self, (x, y), (gz, )):
return gz, -gz
return gz if x.type in grad_types else None, -gz if y.type in grad_types else None
sub = Sub(upcast_out, name = 'sub')
def div_proxy(x, y):
......@@ -593,7 +598,8 @@ class TrueDiv(BinaryScalarOp):
return "%(z)s = ((double)%(x)s) / %(y)s;" % locals()
return "%(z)s = %(x)s / %(y)s;" % locals()
def grad(self, (x, y), (gz, )):
return gz / y, -(gz * x) / (y * y)
return (gz / y if x.type in grad_types else None,
-(gz * x) / (y * y) if y.type in grad_types else None)
true_div = TrueDiv(upcast_out, name = 'true_div')
class IntDiv(BinaryScalarOp):
......@@ -621,7 +627,8 @@ class Pow(BinaryScalarOp):
def c_code(self, node, name, (x, y), (z, ), sub):
return "%(z)s = pow(%(x)s, %(y)s);" % locals()
def grad(self, (x, y), (gz, )):
return gz * y * x**(y - 1), gz * log(x) * x**y
return (gz * y * x**(y - 1) if x.type in grad_types else None,
gz * log(x) * x**y if y.type in grad_types else None)
pow = Pow(upcast_out, name = 'pow')
class Clip(ScalarOp):
......@@ -632,7 +639,7 @@ class Clip(ScalarOp):
return "%(z)s = %(x)s < %(min)s ? %(min)s : %(x)s > %(max)s ? %(max)s : %(x)s;" % locals()
def grad(self, (x, min, max), (gz, )):
gx = ((x > min) & (x < max)) * gz
return gx, None, None
return gx if x.type in grad_types else None, None, None
clip = Clip(transfer_type(0), name = 'clip')
class First(BinaryScalarOp):
......@@ -641,7 +648,7 @@ class First(BinaryScalarOp):
def c_code(self, node, name, (x, y), (z, ), sub):
return "%(z)s = %(x)s;" % locals()
def grad(self, (x, y), (gz, )):
return gz, None
return gz if x.type in grad_type else None, None
first = First(transfer_type(0), name = 'first')
class Second(BinaryScalarOp):
......@@ -650,7 +657,7 @@ class Second(BinaryScalarOp):
def c_code(self, node, name, (x, y), (z, ), sub):
return "%(z)s = %(y)s;" % locals()
def grad(self, (x, y), (gz, )):
return None, gz
return None, gz if y.type in grad_types else None
second = Second(transfer_type(1), name = 'second')
......@@ -661,7 +668,7 @@ class Identity(UnaryScalarOp):
def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = %(x)s;" % locals()
def grad(self, (x, ), (gz, )):
return gz,
return gz if x.type in grad_type else None,
identity = Identity(same_out, name = 'identity')
class Abs(UnaryScalarOp):
......@@ -677,7 +684,7 @@ class Abs(UnaryScalarOp):
def impl(self, x):
return numpy.abs(x)
def grad(self, (x, ), (gz, )):
return gz * sgn(x),
return gz * sgn(x) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub):
type = node.inputs[0].type
if type in int_types:
......@@ -705,8 +712,6 @@ sgn = Sgn(same_out, name = 'sgn')
class IRound(UnaryScalarOp):
def impl(self, x):
return numpy.asarray(numpy.round(x), dtype = 'int64')
def grad(self, (x, ), (gz, )):
return gz * sgn(x),
def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = round(%(x)s);" % locals()
iround = IRound(int_out)
......@@ -715,7 +720,7 @@ class Neg(UnaryScalarOp):
def impl(self, x):
return -x
def grad(self, (x, ), (gz, )):
return -gz,
return -gz if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = -%(x)s;" % locals()
neg = Neg(same_out, name = 'neg')
......@@ -724,7 +729,7 @@ class Inv(UnaryScalarOp):
def impl(self, x):
return 1.0 / x
def grad(self, (x, ), (gz, )):
return -gz / (x * x),
return -gz / (x * x) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = 1.0 / %(x)s;" % locals()
inv = Inv(upgrade_to_float, name = 'inv')
......@@ -733,7 +738,7 @@ class Log(UnaryScalarOp):
def impl(self, x):
return math.log(x)
def grad(self, (x, ), (gz, )):
return gz / x,
return gz / x if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub):
#todo: the version using log2 seems to be very slightly faster
# on some machines for some reason, check if it's worth switching
......@@ -745,7 +750,7 @@ class Log2(UnaryScalarOp):
def impl(self, x):
return numpy.log2(x)
def grad(self, (x, ), (gz, )):
return gz / (x * math.log(2.0)),
return gz / (x * math.log(2.0)) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = log2(%(x)s);" % locals()
log2 = Log2(upgrade_to_float, name = 'log2')
......@@ -754,7 +759,7 @@ class Log10(UnaryScalarOp):
def impl(self, x):
return numpy.log10(x)
def grad(self, (x, ), (gz, )):
return gz / (x * math.log(10.0)),
return gz / (x * math.log(10.0)) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = log10(%(x)s);" % locals()
log10 = Log10(upgrade_to_float, name = 'log10')
......@@ -763,7 +768,7 @@ class Exp(UnaryScalarOp):
def impl(self, x):
return math.exp(x)
def grad(self, (x, ), (gz, )):
return gz * exp(x),
return gz * exp(x) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = exp(%(x)s);" % locals()
exp = Exp(upgrade_to_float, name = 'exp')
......@@ -772,7 +777,7 @@ class Sqr(UnaryScalarOp):
def impl(self, x):
return x*x
def grad(self, (x, ), (gz, )):
return gz * x * 2,
return gz * x * 2 if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = %(x)s * %(x)s;" % locals()
sqr = Sqr(same_out, name = 'sqr')
......@@ -781,7 +786,7 @@ class Sqrt(UnaryScalarOp):
def impl(self, x):
return math.sqrt(x)
def grad(self, (x, ), (gz, )):
return (gz * 0.5) / sqrt(x),
return (gz * 0.5) / sqrt(x) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = sqrt(%(x)s);" % locals()
sqrt = Sqrt(upgrade_to_float, name = 'sqrt')
......@@ -790,7 +795,7 @@ class Cos(UnaryScalarOp):
def impl(self, x):
return math.cos(x)
def grad(self, (x, ), (gz, )):
return -gz * sin(x),
return -gz * sin(x) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = cos(%(x)s);" % locals()
cos = Cos(upgrade_to_float, name = 'cos')
......@@ -799,7 +804,7 @@ class Sin(UnaryScalarOp):
def impl(self, x):
return math.sin(x)
def grad(self, (x, ), (gz, )):
return gz * cos(x),
return gz * cos(x) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = sin(%(x)s);" % locals()
sin = Sin(upgrade_to_float, name = 'sin')
......@@ -808,7 +813,7 @@ class Tan(UnaryScalarOp):
def impl(self, x):
return math.tan(x)
def grad(self, (x, ), (gz, )):
return gz / sqr(cos(x)),
return gz / sqr(cos(x)) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = tan(%(x)s);" % locals()
tan = Tan(upgrade_to_float, name = 'tan')
......@@ -820,7 +825,7 @@ class Cosh(UnaryScalarOp):
def impl(self, x):
return math.cosh(x)
def grad(self, (x, ), (gz, )):
return gz * sinh(x),
return gz * sinh(x) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = cosh(%(x)s);" % locals()
cosh = Cosh(upgrade_to_float, name = 'cosh')
......@@ -832,7 +837,7 @@ class Sinh(UnaryScalarOp):
def impl(self, x):
return math.sinh(x)
def grad(self, (x, ), (gz, )):
return gz * cosh(x),
return gz * cosh(x) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = sinh(%(x)s);" % locals()
sinh = Sinh(upgrade_to_float, name = 'sinh')
......@@ -845,7 +850,7 @@ class Tanh(UnaryScalarOp):
def impl(self, x):
return math.tanh(x)
def grad(self, (x, ), (gz, )):
return gz * (1 - sqr(tanh(x))),
return gz * (1 - sqr(tanh(x))) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = tanh(%(x)s);" % locals()
tanh = Tanh(upgrade_to_float, name = 'tanh')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论