提交 253b3856 authored 作者: Iban Harlouchet's avatar Iban Harlouchet

flake8 for tensor/xlogx.py

上级 a4f0dced
...@@ -13,12 +13,15 @@ class XlogX(scalar.UnaryScalarOp): ...@@ -13,12 +13,15 @@ class XlogX(scalar.UnaryScalarOp):
if x == 0.0: if x == 0.0:
return 0.0 return 0.0
return x * numpy.log(x) return x * numpy.log(x)
def impl(self, x): def impl(self, x):
return XlogX.st_impl(x) return XlogX.st_impl(x)
def grad(self, inputs, grads): def grad(self, inputs, grads):
x, = inputs x, = inputs
gz, = grads gz, = grads
return [gz * (1 + scalar.log(x))] return [gz * (1 + scalar.log(x))]
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
x, = inputs x, = inputs
z, = outputs z, = outputs
...@@ -28,6 +31,7 @@ class XlogX(scalar.UnaryScalarOp): ...@@ -28,6 +31,7 @@ class XlogX(scalar.UnaryScalarOp):
? 0.0 ? 0.0
: %(x)s * log(%(x)s);""" % locals() : %(x)s * log(%(x)s);""" % locals()
raise NotImplementedError('only floatingpoint is implemented') raise NotImplementedError('only floatingpoint is implemented')
scalar_xlogx = XlogX(scalar.upgrade_to_float, name='scalar_xlogx') scalar_xlogx = XlogX(scalar.upgrade_to_float, name='scalar_xlogx')
xlogx = Elemwise(scalar_xlogx, name='xlogx') xlogx = Elemwise(scalar_xlogx, name='xlogx')
...@@ -41,12 +45,15 @@ class XlogY0(scalar.BinaryScalarOp): ...@@ -41,12 +45,15 @@ class XlogY0(scalar.BinaryScalarOp):
if x == 0.0: if x == 0.0:
return 0.0 return 0.0
return x * numpy.log(y) return x * numpy.log(y)
def impl(self, x, y): def impl(self, x, y):
return XlogY0.st_impl(x, y) return XlogY0.st_impl(x, y)
def grad(self, inputs, grads): def grad(self, inputs, grads):
x, y = inputs x, y = inputs
gz, = grads gz, = grads
return [gz * scalar.log(y), gz * x / y] return [gz * scalar.log(y), gz * x / y]
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
x, y = inputs x, y = inputs
z, = outputs z, = outputs
...@@ -56,5 +63,6 @@ class XlogY0(scalar.BinaryScalarOp): ...@@ -56,5 +63,6 @@ class XlogY0(scalar.BinaryScalarOp):
? 0.0 ? 0.0
: %(x)s * log(%(y)s);""" % locals() : %(x)s * log(%(y)s);""" % locals()
raise NotImplementedError('only floatingpoint is implemented') raise NotImplementedError('only floatingpoint is implemented')
scalar_xlogy0 = XlogY0(scalar.upgrade_to_float, name='scalar_xlogy0') scalar_xlogy0 = XlogY0(scalar.upgrade_to_float, name='scalar_xlogy0')
xlogy0 = Elemwise(scalar_xlogy0, name='xlogy0') xlogy0 = Elemwise(scalar_xlogy0, name='xlogy0')
...@@ -57,7 +57,6 @@ whitelist_flake8 = [ ...@@ -57,7 +57,6 @@ whitelist_flake8 = [
"typed_list/tests/test_type.py", "typed_list/tests/test_type.py",
"typed_list/tests/test_opt.py", "typed_list/tests/test_opt.py",
"typed_list/tests/test_basic.py", "typed_list/tests/test_basic.py",
"tensor/xlogx.py",
"tensor/blas_headers.py", "tensor/blas_headers.py",
"tensor/utils.py", "tensor/utils.py",
"tensor/type.py", "tensor/type.py",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论