提交 0517a8f3 authored 作者: John Salvatier's avatar John Salvatier

gave binary and unary bit ops proper gradients instead of none (caused failure)

上级 3b076f98
...@@ -1090,7 +1090,7 @@ class UnaryBitOp(UnaryScalarOp): ...@@ -1090,7 +1090,7 @@ class UnaryBitOp(UnaryScalarOp):
return upcast_out(*input_types[0]) return upcast_out(*input_types[0])
def grad(self, inputs, output_gradients): def grad(self, inputs, output_gradients):
return [None] return [inputs[0].zeros_like().astype(theano.config.floatX)]
class BinaryBitOp(BinaryScalarOp): class BinaryBitOp(BinaryScalarOp):
...@@ -1103,7 +1103,8 @@ class BinaryBitOp(BinaryScalarOp): ...@@ -1103,7 +1103,8 @@ class BinaryBitOp(BinaryScalarOp):
return upcast_out(*input_types[0]) return upcast_out(*input_types[0])
def grad(self, inputs, output_gradients): def grad(self, inputs, output_gradients):
return [None, None] a,b = inputs
return [a.zeros_like().astype(theano.config.floatX), b.zeros_like().astype(theano.config.floatX)]
class OR(BinaryBitOp): class OR(BinaryBitOp):
......
...@@ -743,8 +743,6 @@ class Elemwise(Op): ...@@ -743,8 +743,6 @@ class Elemwise(Op):
scalar_inputs = map(as_scalar, inputs) scalar_inputs = map(as_scalar, inputs)
scalar_ograds = map(as_scalar, ograds) scalar_ograds = map(as_scalar, ograds)
scalar_igrads = self.scalar_op.grad(scalar_inputs, scalar_ograds) scalar_igrads = self.scalar_op.grad(scalar_inputs, scalar_ograds)
for igrad in scalar_igrads:
assert igrad is not None
finally: finally:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论