提交 04cf990b authored 作者: affanv14's avatar affanv14

changed grad to L_op for remaining functions

上级 e57b752e
...@@ -1194,10 +1194,9 @@ class LogicalComparison(BinaryScalarOp): ...@@ -1194,10 +1194,9 @@ class LogicalComparison(BinaryScalarOp):
def output_types(self, *input_dtypes):
    """Return the output dtype list for this comparison op.

    Returns ``[bool]`` when the op instance was constructed with its
    ``bool`` attribute set to a truthy value, and ``[int8]`` otherwise.
    The input dtypes are ignored.
    """
    # getattr with a default keeps older instances (pickled before the
    # 'bool' attribute existed) working — they fall back to int8.
    return [bool] if getattr(self, 'bool', False) else [int8]
def L_op(self, inputs, outputs, output_gradients):
    """Gradient of a logical comparison: zero w.r.t. both inputs.

    Comparison results are piecewise constant, so the gradient with
    respect to each operand is a zero tensor, cast to the configured
    float dtype so downstream graph code gets a float gradient.
    """
    x, y = inputs
    # L_op receives the node's existing outputs, so there is no need to
    # recompute self(x, y) just to inspect its type (as grad() used to).
    assert outputs[0].type == bool
    return [x.zeros_like().astype(theano.config.floatX),
            y.zeros_like().astype(theano.config.floatX)]
...@@ -1436,7 +1435,7 @@ class InRange(LogicalComparison): ...@@ -1436,7 +1435,7 @@ class InRange(LogicalComparison):
else: else:
return elem.zeros_like() return elem.zeros_like()
def grad(self, inputs, gout): def L_op(self, inputs, outputs, gout):
(x, low, hi) = inputs (x, low, hi) = inputs
(gz,) = gout (gz,) = gout
grads = [] grads = []
...@@ -1605,7 +1604,7 @@ class Maximum(BinaryScalarOp): ...@@ -1605,7 +1604,7 @@ class Maximum(BinaryScalarOp):
return ('%(z)s = ((%(y)s)>(%(x)s)? (%(y)s): ' return ('%(z)s = ((%(y)s)>(%(x)s)? (%(y)s): '
'((%(x)s)>=(%(y)s)? (%(x)s): nan("")));' % locals()) '((%(x)s)>=(%(y)s)? (%(x)s): nan("")));' % locals())
def L_op(self, inputs, outputs, gout):
    """Gradient of elementwise maximum.

    The incoming gradient ``gz`` is routed to whichever input equals the
    forward output (both receive it on ties, via ``eq``).  Discrete
    output types get zero gradients; complex gradients are unsupported.
    """
    (x, y) = inputs
    (gz,) = gout
    if gz.type in complex_types:
        # max is currently defined for complex_types,
        # but the gradient for complex is not.
        raise NotImplementedError()
    if outputs[0].type in discrete_types:
        # Integer-valued results are not differentiable; return zeros
        # in the configured float dtype.
        return [x.zeros_like().astype(theano.config.floatX),
                y.zeros_like().astype(theano.config.floatX)]
    # Reuse the node's existing output instead of rebuilding self(x, y).
    gx = eq(outputs[0], x) * gz
    gy = eq(outputs[0], y) * gz
    return (gx, gy)
maximum = Maximum(upcast_out, name='maximum') maximum = Maximum(upcast_out, name='maximum')
...@@ -1643,7 +1640,7 @@ class Minimum(BinaryScalarOp): ...@@ -1643,7 +1640,7 @@ class Minimum(BinaryScalarOp):
return ('%(z)s = ((%(y)s)<(%(x)s)? (%(y)s): ' return ('%(z)s = ((%(y)s)<(%(x)s)? (%(y)s): '
'((%(x)s)<=(%(y)s)? (%(x)s): nan("")));' % locals()) '((%(x)s)<=(%(y)s)? (%(x)s): nan("")));' % locals())
def L_op(self, inputs, outputs, gout):
    """Gradient of elementwise minimum.

    Mirrors ``Maximum.L_op``: the incoming gradient ``gz`` flows to
    whichever input equals the forward output (both on ties, via
    ``eq``).  Discrete output types get zero gradients; complex
    gradients are unsupported.
    """
    (x, y) = inputs
    (gz,) = gout
    if gz.type in complex_types:
        # min is currently defined for complex_types,
        # but the gradient for complex is not.
        raise NotImplementedError()
    if outputs[0].type in discrete_types:
        # Integer-valued results are not differentiable; return zeros
        # in the configured float dtype.
        return [x.zeros_like().astype(theano.config.floatX),
                y.zeros_like().astype(theano.config.floatX)]
    # Reuse the node's existing output instead of recomputing minimum(x, y).
    gx = eq(outputs[0], x) * gz
    gy = eq(outputs[0], y) * gz
    return (gx, gy)
minimum = Minimum(upcast_out, name='minimum') minimum = Minimum(upcast_out, name='minimum')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论