提交 253b3856 authored 作者: Iban Harlouchet's avatar Iban Harlouchet

flake8 for tensor/xlogx.py

上级 a4f0dced
......@@ -13,12 +13,15 @@ class XlogX(scalar.UnaryScalarOp):
if x == 0.0:
return 0.0
return x * numpy.log(x)
def impl(self, x):
return XlogX.st_impl(x)
def grad(self, inputs, grads):
x, = inputs
gz, = grads
return [gz * (1 + scalar.log(x))]
def c_code(self, node, name, inputs, outputs, sub):
x, = inputs
z, = outputs
......@@ -28,7 +31,8 @@ class XlogX(scalar.UnaryScalarOp):
? 0.0
: %(x)s * log(%(x)s);""" % locals()
raise NotImplementedError('only floatingpoint is implemented')
scalar_xlogx = XlogX(scalar.upgrade_to_float, name='scalar_xlogx')
scalar_xlogx = XlogX(scalar.upgrade_to_float, name='scalar_xlogx')
xlogx = Elemwise(scalar_xlogx, name='xlogx')
......@@ -41,12 +45,15 @@ class XlogY0(scalar.BinaryScalarOp):
if x == 0.0:
return 0.0
return x * numpy.log(y)
def impl(self, x, y):
return XlogY0.st_impl(x, y)
def grad(self, inputs, grads):
x, y = inputs
gz, = grads
return [gz * scalar.log(y), gz * x / y]
def c_code(self, node, name, inputs, outputs, sub):
x, y = inputs
z, = outputs
......@@ -56,5 +63,6 @@ class XlogY0(scalar.BinaryScalarOp):
? 0.0
: %(x)s * log(%(y)s);""" % locals()
raise NotImplementedError('only floatingpoint is implemented')
scalar_xlogy0 = XlogY0(scalar.upgrade_to_float, name='scalar_xlogy0')
scalar_xlogy0 = XlogY0(scalar.upgrade_to_float, name='scalar_xlogy0')
xlogy0 = Elemwise(scalar_xlogy0, name='xlogy0')
......@@ -57,7 +57,6 @@ whitelist_flake8 = [
"typed_list/tests/test_type.py",
"typed_list/tests/test_opt.py",
"typed_list/tests/test_basic.py",
"tensor/xlogx.py",
"tensor/blas_headers.py",
"tensor/utils.py",
"tensor/type.py",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论