Added docstring and TODO comments to dtype_specs; added a multi-shape loop check to the T_abs tests

Parent commit d26f29d3
...@@ -148,8 +148,9 @@ class T_abs(unittest.TestCase): ...@@ -148,8 +148,9 @@ class T_abs(unittest.TestCase):
check_eq(self, t, abs(t), 1.0, 1.0) check_eq(self, t, abs(t), 1.0, 1.0)
check_eq(self, t, abs(t), -1.0, 1.0) check_eq(self, t, abs(t), -1.0, 1.0)
t = tensor([0.0, 0.0]) for shape in (2,), (3,4):
d = numpy.asarray([-0.4, 1.2]) t = tensor(numpy.ones(shape))
d = numpy.random.rand(*shape)*2-1.0
check_eq(self, t, abs(t), d, abs(d)) check_eq(self, t, abs(t), d, abs(d))
check_eq(self, t, abs(t), -d, abs(-d)) check_eq(self, t, abs(t), -d, abs(-d))
...@@ -160,7 +161,7 @@ class T_abs(unittest.TestCase): ...@@ -160,7 +161,7 @@ class T_abs(unittest.TestCase):
def impl(self, x): def impl(self, x):
return numpy.abs(x) return numpy.abs(x)
def grad(self, x, gz): def grad(self, x, gz):
return -gz * sgn(x) return scale(gz * sgn(x),0.9)
def c_foreach(self, (x_i, ), (z_i, )): def c_foreach(self, (x_i, ), (z_i, )):
return "z_i = abs(x_i);" return "z_i = abs(x_i);"
......
...@@ -70,6 +70,13 @@ class Tensor(ResultBase): ...@@ -70,6 +70,13 @@ class Tensor(ResultBase):
# type information : Olivier what does this mean? # type information : Olivier what does this mean?
# #
def dtype_specs(self): def dtype_specs(self):
"""Return python - C type correspondance tuple for self.data
Return a tuple (python type, c type, numpy typenum) that corresponds to
self.dtype. It is for use in C code generation.
"""
#TODO: add more type correspondances for e.g. int32, int64, float32,
#complex64, etc.
return {'float64': (float, 'double', 'NPY_DOUBLE')}[self.dtype] return {'float64': (float, 'double', 'NPY_DOUBLE')}[self.dtype]
# #
...@@ -578,13 +585,12 @@ if 0: ...@@ -578,13 +585,12 @@ if 0:
sub = _scalar_switch(sub_elemwise, sub_scalar_r, sub_scalar_l) sub = _scalar_switch(sub_elemwise, sub_scalar_r, sub_scalar_l)
sub_inplace = _scalar_switch(sub_elemwise_inplace, sub_scalar_rinplace) sub_inplace = _scalar_switch(sub_elemwise_inplace, sub_scalar_rinplace)
if 1: ##########################
########################## # Arithmetic : Mul
# Arithmetic : Mul ##########################
##########################
# Elemwise # # Elemwise #
class MulElemwise(_Elemwise): class MulElemwise(_Elemwise):
def impl(self, x, y): def impl(self, x, y):
_assert_same_shapes(x, y) _assert_same_shapes(x, y)
return x * y return x * y
...@@ -592,34 +598,34 @@ if 1: ...@@ -592,34 +598,34 @@ if 1:
return mul(y, gz), mul(x, gz) return mul(y, gz), mul(x, gz)
def c_foreach(self, (x_i, y_i), (z_i, )): def c_foreach(self, (x_i, y_i), (z_i, )):
return "z_i = x_i * y_i;" return "z_i = x_i * y_i;"
mul_elemwise = constructor(MulElemwise) mul_elemwise = constructor(MulElemwise)
class MulElemwiseInplace(MulElemwise.inplace_version()): class MulElemwiseInplace(MulElemwise.inplace_version()):
def impl(self, x, y): def impl(self, x, y):
_assert_same_shapes(x, y) _assert_same_shapes(x, y)
x *= y x *= y
return x return x
mul_elemwise_inplace = constructor(MulElemwiseInplace) mul_elemwise_inplace = constructor(MulElemwiseInplace)
# Scalar # # Scalar #
class Scale(TensorScalarOp): class Scale(TensorScalarOp):
def impl(self, x, a): def impl(self, x, a):
_assert_tensor_scalar(x, a) _assert_tensor_scalar(x, a)
return x * a return x * a
def grad(self, (x, a), gz): def grad(self, (x, a), gz):
return scale(a, gz), sum(mul_elemwise(x, gz)) return scale(a, gz), sum(mul_elemwise(x, gz))
c_expr = "x_i * a" c_expr = "x_i * a"
scale = constructor(Scale) scale = constructor(Scale)
class ScaleInplace(Scale.inplace_version()): class ScaleInplace(Scale.inplace_version()):
def impl(self, x, a): def impl(self, x, a):
_assert_tensor_scalar(x, a) _assert_tensor_scalar(x, a)
x *= a x *= a
return x return x
scale_inplace = constructor(ScaleInplace) scale_inplace = constructor(ScaleInplace)
mul = _scalar_switch(mul_elemwise, scale, scale) mul = _scalar_switch(mul_elemwise, scale, scale)
mul_inplace = _scalar_switch(mul_elemwise_inplace, scale_inplace) mul_inplace = _scalar_switch(mul_elemwise_inplace, scale_inplace)
if 0: if 0:
......
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment