提交 5dfbb799 authored 作者: james@mackie's avatar james@mackie

better impl of inv.grad

上级 3176c711
...@@ -614,7 +614,8 @@ class InvElemwise(_Elemwise): ...@@ -614,7 +614,8 @@ class InvElemwise(_Elemwise):
def impl(self, x): def impl(self, x):
return 1.0/x return 1.0/x
def grad(self, x, gz): def grad(self, x, gz):
return -gz / (x*x) ix = inv(x)
return -gz * (ix * ix)
def c_foreach(self, (x_i, ), (z_i, )): def c_foreach(self, (x_i, ), (z_i, )):
return "%(z)s_i = 1.0 / %(x)s_i;" #TODO: cast 1.0 to the dtype of x return "%(z)s_i = 1.0 / %(x)s_i;" #TODO: cast 1.0 to the dtype of x
inv_elemwise = _constructor(InvElemwise) inv_elemwise = _constructor(InvElemwise)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论