moved comments, added gradient test for log2

上级 eea1ba8e
...@@ -441,6 +441,9 @@ class T_div(unittest.TestCase): ...@@ -441,6 +441,9 @@ class T_div(unittest.TestCase):
verify_grad(self, DivElemwise, [numpy.random.rand(3), numpy.ones(3)]) verify_grad(self, DivElemwise, [numpy.random.rand(3), numpy.ones(3)])
verify_grad(self, DivElemwise, [numpy.random.rand(3,5), numpy.random.rand(3,5)+0.1]) verify_grad(self, DivElemwise, [numpy.random.rand(3,5), numpy.random.rand(3,5)+0.1])
class T_log2(unittest.TestCase):
def test0(self):
verify_grad(self, Log2, [numpy.random.rand(3,1)+0.0001])
class T_pow(unittest.TestCase): class T_pow(unittest.TestCase):
def setUp(self): def setUp(self):
......
...@@ -130,31 +130,6 @@ class _Op(BaseTensorOp): ...@@ -130,31 +130,6 @@ class _Op(BaseTensorOp):
def input_wrapper(cls, obj): def input_wrapper(cls, obj):
return _as_tensor(obj) return _as_tensor(obj)
# def upcast(dtype, *dtypes):
# z = numpy.zeros((), dtype = dtype)
# for dtype in dtypes:
# z = z + numpy.zeros((), dtype = dtype)
# return str(z.dtype)
# for dtype in i_dtypes:
# if dtype is None:
# raise TypeError("Expected a Tensor.")
# upcasted = upcast(*i_dtypes)
# return [upcasted] * self.nout
# # try:
# # dmap = self.destroy_map()
# # except AttributeError:
# # dmap = {}
# # rval = []
# # for i in xrange(self.nout):
# # if i in dmap:
# # destroyed = dmap[output]
# # if len(destroyed) != 1:
# # raise TypeError("Cannot infer dtype of output %s because it destroys more than one input." % output)
# # rval.append(destroyed[0])
# # else:
# # rval.append(upcasted)
# # return rval
def impl(self, *inputs): def impl(self, *inputs):
raise AbstractFunctionError() raise AbstractFunctionError()
...@@ -811,3 +786,28 @@ if 0: ...@@ -811,3 +786,28 @@ if 0:
return t return t
# def upcast(dtype, *dtypes):
# z = numpy.zeros((), dtype = dtype)
# for dtype in dtypes:
# z = z + numpy.zeros((), dtype = dtype)
# return str(z.dtype)
# for dtype in i_dtypes:
# if dtype is None:
# raise TypeError("Expected a Tensor.")
# upcasted = upcast(*i_dtypes)
# return [upcasted] * self.nout
# # try:
# # dmap = self.destroy_map()
# # except AttributeError:
# # dmap = {}
# # rval = []
# # for i in xrange(self.nout):
# # if i in dmap:
# # destroyed = dmap[output]
# # if len(destroyed) != 1:
# # raise TypeError("Cannot infer dtype of output %s because it destroys more than one input." % output)
# # rval.append(destroyed[0])
# # else:
# # rval.append(upcasted)
# # return rval
...@@ -62,98 +62,3 @@ class Dot(TensorOp): ...@@ -62,98 +62,3 @@ class Dot(TensorOp):
class Min:
pass
class Max:
pass
class Argmin:
pass
class Argmax:
pass
class MinMax:
pass
# nout = 2
# def impl(x):
# return x.min, x.max
# def specs(x):
# return [(numpy.ndarray, x[1], ())] * 2
# # def alloc((x, ), (_min, _max)):
# # _min.data = numpy.ndarray((), x.dtype)
# # _max.data = numpy.ndarray((), x.dtype)
# def c_init((x, ), (_min, _max)):
# raise NotImplementedError
# return """
# _x_dtype min = _x[0];
# _x_dtype max = _x[0];
# """
# def c_foreach((x, ), (_min, _max)):
# return """
# if (x < min) min = x;
# if (x > max) max = x;
# """
# def c_finalize((x, ), (_min, _max)):
# return """
# _min[0] = min;
# _max[0] = max;
# """
# class Transpose(UnaryTensorOp):
# def propagate_broadcastable(self, x):
# x2 = copy(x)
# x2.reverse()
# return [x2]
# def impl(self, x):
# return x.T
# def c_impl(self, x, z):
# return """
# PyArrayObject* transposed = (PyArrayObject*)PyArray_Transpose(%(x)s, NULL);
# //if (PyArray_REFCOUNT(transposed) == 1) {
# // printf("lala\\n");
# //}
# //if (%(z)s) {
# // Py_XDECREF(%(z)s);
# //}
# %(z)s = transposed;
# Py_XINCREF(%(z)s);
# """
# # class Transpose(UnaryTensorOp):
# # def propagate_broadcastable(self, x):
# # x2 = copy(x)
# # x2.reverse()
# # return [x2]
# # def impl(self, x):
# # return x.T
# # def c_impl(self, x, z):
# # return """
# # PyArrayObject* transposed = (PyArrayObject*)PyArray_Transpose(%(x)s, NULL);
# # //if (PyArray_REFCOUNT(transposed) == 1) {
# # // printf("lala\\n");
# # //}
# # //if (%(z)s) {
# # // Py_XDECREF(%(z)s);
# # //}
# # %(z)s = transposed;
# # Py_XINCREF(%(z)s);
# # """
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论