rewrite of pow functions

上级 f4c3ad95
...@@ -423,26 +423,40 @@ class array_copy(omega_op): ...@@ -423,26 +423,40 @@ class array_copy(omega_op):
## Power ## ## Power ##
class proto_pow(omega_op): class proto_pow(omega_op):
def grad(x, y, gz):
pass pass
class pow(proto_pow): class pow_elemwise(proto_pow):
impl = numpy.ndarray.__pow__ impl = assert_same_shapes(numpy.ndarray.__pow__)
def grad(x, s, gz):
return gz * s * (pow_elemwise(x, s-1.0))
class ipow(proto_pow, inplace): class pow_scalar_l(proto_pow):
impl = numpy.ndarray.__ipow__ impl = tensor_scalar_op(numpy.ndarray.__pow__)
def grad(x, s, gz):
return gz * x * (pow_scalar_l(s,x-1.0))
class pow_scalar_r(proto_pow):
impl = tensor_scalar_op(numpy.ndarray.__pow__)
def grad(x, s, gz):
return gz * s * (pow_scalar_r(x,s-1.0))
class proto_pow_elemwise(omega_op): class proto_ipow(omega_op):
def grad(x, y, gz):
pass pass
class pow_elemwise(proto_pow_elemwise): class ipow_elemwise(proto_ipow):
impl = numpy.ndarray.__pow__ def __init__(self, *args, **kwargs):
omega_op.__init__(self, *args, **kwargs)
raise NotImplementedError()
class ipow_elemwise(proto_pow_elemwise, inplace): class ipow_scalar_l(proto_ipow):
impl = numpy.ndarray.__ipow__ def __init__(self, *args, **kwargs):
omega_op.__init__(self, *args, **kwargs)
raise NotImplementedError()
class ipow_scalar_r(proto_ipow):
def __init__(self, *args, **kwargs):
omega_op.__init__(self, *args, **kwargs)
raise NotImplementedError()
## Others ## ## Others ##
...@@ -482,8 +496,8 @@ imul = scalar_switch(imul_elemwise, iscale, iscale) ...@@ -482,8 +496,8 @@ imul = scalar_switch(imul_elemwise, iscale, iscale)
div = scalar_switch(div_elemwise, div_scalar_r, div_scalar_l) div = scalar_switch(div_elemwise, div_scalar_r, div_scalar_l)
idiv = scalar_switch(idiv_elemwise, idiv_scalar_r, idiv_scalar_l) idiv = scalar_switch(idiv_elemwise, idiv_scalar_r, idiv_scalar_l)
# pow = scalar_switch(pow_elemwise, pow_scalar_r, pow_scalar_l) pow = scalar_switch(pow_elemwise, pow_scalar_r, pow_scalar_l)
# ipow = scalar_switch(ipow_elemwise, ipow_scalar_r, ipow_scalar_l) ipow = scalar_switch(ipow_elemwise, ipow_scalar_r, ipow_scalar_l)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论