提交 2f9672b3 authored 作者: James Bergstra's avatar James Bergstra

modified scalar ops to be nondifferentiable for integer arguments

上级 98603703
...@@ -183,6 +183,8 @@ int_types = int8, int16, int32, int64 ...@@ -183,6 +183,8 @@ int_types = int8, int16, int32, int64
float_types = float32, float64 float_types = float32, float64
complex_types = complex64, complex128 complex_types = complex64, complex128
grad_types = float_types + complex_types # these are the types for which gradients can be defined.
class _scalar_py_operators: class _scalar_py_operators:
#UNARY #UNARY
...@@ -463,7 +465,9 @@ class Switch(ScalarOp): ...@@ -463,7 +465,9 @@ class Switch(ScalarOp):
def c_code(self, node, name, (cond, ift, iff), (z, ), sub): def c_code(self, node, name, (cond, ift, iff), (z, ), sub):
return "%(z)s = %(cond)s ? %(ift)s : %(iff)s;" % locals() return "%(z)s = %(cond)s ? %(ift)s : %(iff)s;" % locals()
def grad(self, (cond, ift, iff), (gz, )): def grad(self, (cond, ift, iff), (gz, )):
return None, switch(cond, gz, 0), switch(cond, 0, gz) return (None,
switch(cond, gz, 0) if ift.type in grad_types else None,
switch(cond, 0, gz) if iff.type in grad_types else None)
def output_types(self, (cond_t, ift_t, iff_t)): def output_types(self, (cond_t, ift_t, iff_t)):
return upcast_out(ift_t, iff_t) return upcast_out(ift_t, iff_t)
switch = Switch() switch = Switch()
...@@ -539,7 +543,7 @@ class Add(ScalarOp): ...@@ -539,7 +543,7 @@ class Add(ScalarOp):
else: else:
return z + " = " + " + ".join(inputs) + ";" return z + " = " + " + ".join(inputs) + ";"
def grad(self, inputs, (gz, )): def grad(self, inputs, (gz, )):
return (gz, ) * len(inputs) return [(gz if i.type in grad_types else None) for i in inputs]
add = Add(upcast_out, name = 'add') add = Add(upcast_out, name = 'add')
class Mul(ScalarOp): class Mul(ScalarOp):
...@@ -554,7 +558,8 @@ class Mul(ScalarOp): ...@@ -554,7 +558,8 @@ class Mul(ScalarOp):
else: else:
return z + " = " + " * ".join(inputs) + ";" return z + " = " + " * ".join(inputs) + ";"
def grad(self, inputs, (gz, )): def grad(self, inputs, (gz, )):
return [mul(*([gz] + utils.difference(inputs, [input]))) return [(mul(*([gz] + utils.difference(inputs, [input])))
if input.type in grad_types else None)
for input in inputs] for input in inputs]
mul = Mul(upcast_out, name = 'mul') mul = Mul(upcast_out, name = 'mul')
...@@ -564,7 +569,7 @@ class Sub(BinaryScalarOp): ...@@ -564,7 +569,7 @@ class Sub(BinaryScalarOp):
def c_code(self, node, name, (x, y), (z, ), sub): def c_code(self, node, name, (x, y), (z, ), sub):
return "%(z)s = %(x)s - %(y)s;" % locals() return "%(z)s = %(x)s - %(y)s;" % locals()
def grad(self, (x, y), (gz, )): def grad(self, (x, y), (gz, )):
return gz, -gz return gz if x.type in grad_types else None, -gz if y.type in grad_types else None
sub = Sub(upcast_out, name = 'sub') sub = Sub(upcast_out, name = 'sub')
def div_proxy(x, y): def div_proxy(x, y):
...@@ -593,7 +598,8 @@ class TrueDiv(BinaryScalarOp): ...@@ -593,7 +598,8 @@ class TrueDiv(BinaryScalarOp):
return "%(z)s = ((double)%(x)s) / %(y)s;" % locals() return "%(z)s = ((double)%(x)s) / %(y)s;" % locals()
return "%(z)s = %(x)s / %(y)s;" % locals() return "%(z)s = %(x)s / %(y)s;" % locals()
def grad(self, (x, y), (gz, )): def grad(self, (x, y), (gz, )):
return gz / y, -(gz * x) / (y * y) return (gz / y if x.type in grad_types else None,
-(gz * x) / (y * y) if y.type in grad_types else None)
true_div = TrueDiv(upcast_out, name = 'true_div') true_div = TrueDiv(upcast_out, name = 'true_div')
class IntDiv(BinaryScalarOp): class IntDiv(BinaryScalarOp):
...@@ -621,7 +627,8 @@ class Pow(BinaryScalarOp): ...@@ -621,7 +627,8 @@ class Pow(BinaryScalarOp):
def c_code(self, node, name, (x, y), (z, ), sub): def c_code(self, node, name, (x, y), (z, ), sub):
return "%(z)s = pow(%(x)s, %(y)s);" % locals() return "%(z)s = pow(%(x)s, %(y)s);" % locals()
def grad(self, (x, y), (gz, )): def grad(self, (x, y), (gz, )):
return gz * y * x**(y - 1), gz * log(x) * x**y return (gz * y * x**(y - 1) if x.type in grad_types else None,
gz * log(x) * x**y if y.type in grad_types else None)
pow = Pow(upcast_out, name = 'pow') pow = Pow(upcast_out, name = 'pow')
class Clip(ScalarOp): class Clip(ScalarOp):
...@@ -632,7 +639,7 @@ class Clip(ScalarOp): ...@@ -632,7 +639,7 @@ class Clip(ScalarOp):
return "%(z)s = %(x)s < %(min)s ? %(min)s : %(x)s > %(max)s ? %(max)s : %(x)s;" % locals() return "%(z)s = %(x)s < %(min)s ? %(min)s : %(x)s > %(max)s ? %(max)s : %(x)s;" % locals()
def grad(self, (x, min, max), (gz, )): def grad(self, (x, min, max), (gz, )):
gx = ((x > min) & (x < max)) * gz gx = ((x > min) & (x < max)) * gz
return gx, None, None return gx if x.type in grad_types else None, None, None
clip = Clip(transfer_type(0), name = 'clip') clip = Clip(transfer_type(0), name = 'clip')
class First(BinaryScalarOp): class First(BinaryScalarOp):
...@@ -641,7 +648,7 @@ class First(BinaryScalarOp): ...@@ -641,7 +648,7 @@ class First(BinaryScalarOp):
def c_code(self, node, name, (x, y), (z, ), sub): def c_code(self, node, name, (x, y), (z, ), sub):
return "%(z)s = %(x)s;" % locals() return "%(z)s = %(x)s;" % locals()
def grad(self, (x, y), (gz, )): def grad(self, (x, y), (gz, )):
return gz, None return gz if x.type in grad_type else None, None
first = First(transfer_type(0), name = 'first') first = First(transfer_type(0), name = 'first')
class Second(BinaryScalarOp): class Second(BinaryScalarOp):
...@@ -650,7 +657,7 @@ class Second(BinaryScalarOp): ...@@ -650,7 +657,7 @@ class Second(BinaryScalarOp):
def c_code(self, node, name, (x, y), (z, ), sub): def c_code(self, node, name, (x, y), (z, ), sub):
return "%(z)s = %(y)s;" % locals() return "%(z)s = %(y)s;" % locals()
def grad(self, (x, y), (gz, )): def grad(self, (x, y), (gz, )):
return None, gz return None, gz if y.type in grad_types else None
second = Second(transfer_type(1), name = 'second') second = Second(transfer_type(1), name = 'second')
...@@ -661,7 +668,7 @@ class Identity(UnaryScalarOp): ...@@ -661,7 +668,7 @@ class Identity(UnaryScalarOp):
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = %(x)s;" % locals() return "%(z)s = %(x)s;" % locals()
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return gz, return gz if x.type in grad_type else None,
identity = Identity(same_out, name = 'identity') identity = Identity(same_out, name = 'identity')
class Abs(UnaryScalarOp): class Abs(UnaryScalarOp):
...@@ -677,7 +684,7 @@ class Abs(UnaryScalarOp): ...@@ -677,7 +684,7 @@ class Abs(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return numpy.abs(x) return numpy.abs(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return gz * sgn(x), return gz * sgn(x) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
type = node.inputs[0].type type = node.inputs[0].type
if type in int_types: if type in int_types:
...@@ -705,8 +712,6 @@ sgn = Sgn(same_out, name = 'sgn') ...@@ -705,8 +712,6 @@ sgn = Sgn(same_out, name = 'sgn')
class IRound(UnaryScalarOp): class IRound(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return numpy.asarray(numpy.round(x), dtype = 'int64') return numpy.asarray(numpy.round(x), dtype = 'int64')
def grad(self, (x, ), (gz, )):
return gz * sgn(x),
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = round(%(x)s);" % locals() return "%(z)s = round(%(x)s);" % locals()
iround = IRound(int_out) iround = IRound(int_out)
...@@ -715,7 +720,7 @@ class Neg(UnaryScalarOp): ...@@ -715,7 +720,7 @@ class Neg(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return -x return -x
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return -gz, return -gz if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = -%(x)s;" % locals() return "%(z)s = -%(x)s;" % locals()
neg = Neg(same_out, name = 'neg') neg = Neg(same_out, name = 'neg')
...@@ -724,7 +729,7 @@ class Inv(UnaryScalarOp): ...@@ -724,7 +729,7 @@ class Inv(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return 1.0 / x return 1.0 / x
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return -gz / (x * x), return -gz / (x * x) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = 1.0 / %(x)s;" % locals() return "%(z)s = 1.0 / %(x)s;" % locals()
inv = Inv(upgrade_to_float, name = 'inv') inv = Inv(upgrade_to_float, name = 'inv')
...@@ -733,7 +738,7 @@ class Log(UnaryScalarOp): ...@@ -733,7 +738,7 @@ class Log(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return math.log(x) return math.log(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return gz / x, return gz / x if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
#todo: the version using log2 seems to be very slightly faster #todo: the version using log2 seems to be very slightly faster
# on some machines for some reason, check if it's worth switching # on some machines for some reason, check if it's worth switching
...@@ -745,7 +750,7 @@ class Log2(UnaryScalarOp): ...@@ -745,7 +750,7 @@ class Log2(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return numpy.log2(x) return numpy.log2(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return gz / (x * math.log(2.0)), return gz / (x * math.log(2.0)) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = log2(%(x)s);" % locals() return "%(z)s = log2(%(x)s);" % locals()
log2 = Log2(upgrade_to_float, name = 'log2') log2 = Log2(upgrade_to_float, name = 'log2')
...@@ -754,7 +759,7 @@ class Log10(UnaryScalarOp): ...@@ -754,7 +759,7 @@ class Log10(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return numpy.log10(x) return numpy.log10(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return gz / (x * math.log(10.0)), return gz / (x * math.log(10.0)) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = log10(%(x)s);" % locals() return "%(z)s = log10(%(x)s);" % locals()
log10 = Log10(upgrade_to_float, name = 'log10') log10 = Log10(upgrade_to_float, name = 'log10')
...@@ -763,7 +768,7 @@ class Exp(UnaryScalarOp): ...@@ -763,7 +768,7 @@ class Exp(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return math.exp(x) return math.exp(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return gz * exp(x), return gz * exp(x) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = exp(%(x)s);" % locals() return "%(z)s = exp(%(x)s);" % locals()
exp = Exp(upgrade_to_float, name = 'exp') exp = Exp(upgrade_to_float, name = 'exp')
...@@ -772,7 +777,7 @@ class Sqr(UnaryScalarOp): ...@@ -772,7 +777,7 @@ class Sqr(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return x*x return x*x
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return gz * x * 2, return gz * x * 2 if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = %(x)s * %(x)s;" % locals() return "%(z)s = %(x)s * %(x)s;" % locals()
sqr = Sqr(same_out, name = 'sqr') sqr = Sqr(same_out, name = 'sqr')
...@@ -781,7 +786,7 @@ class Sqrt(UnaryScalarOp): ...@@ -781,7 +786,7 @@ class Sqrt(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return math.sqrt(x) return math.sqrt(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return (gz * 0.5) / sqrt(x), return (gz * 0.5) / sqrt(x) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = sqrt(%(x)s);" % locals() return "%(z)s = sqrt(%(x)s);" % locals()
sqrt = Sqrt(upgrade_to_float, name = 'sqrt') sqrt = Sqrt(upgrade_to_float, name = 'sqrt')
...@@ -790,7 +795,7 @@ class Cos(UnaryScalarOp): ...@@ -790,7 +795,7 @@ class Cos(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return math.cos(x) return math.cos(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return -gz * sin(x), return -gz * sin(x) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = cos(%(x)s);" % locals() return "%(z)s = cos(%(x)s);" % locals()
cos = Cos(upgrade_to_float, name = 'cos') cos = Cos(upgrade_to_float, name = 'cos')
...@@ -799,7 +804,7 @@ class Sin(UnaryScalarOp): ...@@ -799,7 +804,7 @@ class Sin(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return math.sin(x) return math.sin(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return gz * cos(x), return gz * cos(x) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = sin(%(x)s);" % locals() return "%(z)s = sin(%(x)s);" % locals()
sin = Sin(upgrade_to_float, name = 'sin') sin = Sin(upgrade_to_float, name = 'sin')
...@@ -808,7 +813,7 @@ class Tan(UnaryScalarOp): ...@@ -808,7 +813,7 @@ class Tan(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return math.tan(x) return math.tan(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return gz / sqr(cos(x)), return gz / sqr(cos(x)) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = tan(%(x)s);" % locals() return "%(z)s = tan(%(x)s);" % locals()
tan = Tan(upgrade_to_float, name = 'tan') tan = Tan(upgrade_to_float, name = 'tan')
...@@ -820,7 +825,7 @@ class Cosh(UnaryScalarOp): ...@@ -820,7 +825,7 @@ class Cosh(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return math.cosh(x) return math.cosh(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return gz * sinh(x), return gz * sinh(x) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = cosh(%(x)s);" % locals() return "%(z)s = cosh(%(x)s);" % locals()
cosh = Cosh(upgrade_to_float, name = 'cosh') cosh = Cosh(upgrade_to_float, name = 'cosh')
...@@ -832,7 +837,7 @@ class Sinh(UnaryScalarOp): ...@@ -832,7 +837,7 @@ class Sinh(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return math.sinh(x) return math.sinh(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return gz * cosh(x), return gz * cosh(x) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = sinh(%(x)s);" % locals() return "%(z)s = sinh(%(x)s);" % locals()
sinh = Sinh(upgrade_to_float, name = 'sinh') sinh = Sinh(upgrade_to_float, name = 'sinh')
...@@ -845,7 +850,7 @@ class Tanh(UnaryScalarOp): ...@@ -845,7 +850,7 @@ class Tanh(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return math.tanh(x) return math.tanh(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return gz * (1 - sqr(tanh(x))), return gz * (1 - sqr(tanh(x))) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = tanh(%(x)s);" % locals() return "%(z)s = tanh(%(x)s);" % locals()
tanh = Tanh(upgrade_to_float, name = 'tanh') tanh = Tanh(upgrade_to_float, name = 'tanh')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论