more cleanup in tensor.py

上级 14ab328e
...@@ -398,13 +398,9 @@ class T_abs(unittest.TestCase): ...@@ -398,13 +398,9 @@ class T_abs(unittest.TestCase):
verify_grad(self, Abs, [numpy.ones(())]) verify_grad(self, Abs, [numpy.ones(())])
verify_grad(self, Abs, [numpy.ones(3)]) verify_grad(self, Abs, [numpy.ones(3)])
class AbsBadGrad(tensor._Elemwise): class AbsBadGrad(Abs):
def impl(self, x):
return numpy.abs(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return mul(gz * sgn(x),0.9), return mul(gz * sgn(x),0.9),
def c_foreach(self, (x_i, ), (z_i, )):
return "z_i = abs(x_i);"
def test_badgrad(self): def test_badgrad(self):
try: try:
......
"""A simple class to store ndarray data """ """A simple class to store ndarray data """
from gof import ResultBase, Op, utils from gof import ResultBase, Op, utils, AbstractFunctionError
import numpy import numpy
from copy import copy from copy import copy
......
...@@ -9,7 +9,6 @@ import gof.result ...@@ -9,7 +9,6 @@ import gof.result
import gof.op import gof.op
from base_tensor import BaseTensor, BaseTensorOp from base_tensor import BaseTensor, BaseTensorOp
from elemwise import Elemwise
import blas # for gemm, dot import blas # for gemm, dot
import elemwise2 as s2t import elemwise2 as s2t
...@@ -159,78 +158,6 @@ class _Binary: ...@@ -159,78 +158,6 @@ class _Binary:
nin = 2 nin = 2
class _Elemwise(Elemwise, _Op):
@staticmethod
def extract_name(name):
if name.endswith("_i"):
return name[:-2]
else:
return name
@staticmethod
def is_loop_var(name):
return name.endswith("_i")
def var_desc(self):
cls = self.__class__
(self, inames, onames), _1, _2, _3 = inspect.getargspec(cls.c_foreach)
return ([(cls.extract_name(name), cls.is_loop_var(name)) for name in inames],
[(cls.extract_name(name), cls.is_loop_var(name)) for name in onames])
def propagate_broadcastable(self, *inputs):
idesc, odesc = self.var_desc()
nonloop_o = [o[0] for o in odesc if not o[1]]
if nonloop_o:
raise Exception("Cannot infer broadcastable for non-loop variable(s) %s" % nonloop_o)
all_bcast = [broadcastable for broadcastable, i in zip(inputs, idesc) if i[1]]
if reduce(lambda x, y: x is not False and x == y and y, [len(x) for x in all_bcast]) is False:
raise TypeError(_Elemwise.propagate_broadcastable.E_ndim, self.__class__)
ret = []
for arr in zip(*all_bcast):
if 0 in arr:
ret.append(0)
else:
ret.append(1)
return [ret] * self.nout
propagate_broadcastable.E_ndim \
= "Inputs that are loop variables do not all have the same number of dimensions."
def c_init(self, inputs, outputs):
raise AbstractFunctionError()
def c_foreach(self, inputs, outputs):
raise AbstractFunctionError()
def c_finalize(self, inputs, outputs):
raise AbstractFunctionError()
def c_code_init(self):
return self.c_init(self.inputs, self.outputs)
def c_code_foreach(self):
return self.c_foreach(self.inputs, self.outputs)
def c_code_finalize(self):
return self.c_finalize(self.inputs, self.outputs)
class TensorScalarOp(_Elemwise):
def var_desc(self):
return [('x', 1), ('a', 0)], [('z', 1)]
def c_code_init(self):
return """
dtype_%(a)s _%(a)s;
if (PyArray_SIZE(%(a)s) != 1) {
PyErr_SetString(PyExc_ValueError, \"The size of the scalar argument is not 1.\");
%(fail)s
}
_%(a)s = ((dtype_%(a)s*)PyArray_DATA(%(a)s))[0];
"""
def c_code_foreach(self):
return "%%(z)s_i = %s;" % self.c_expr
########################## ##########################
# Unary Operations # Unary Operations
########################## ##########################
...@@ -314,14 +241,14 @@ class TransposeInplace(_Op, Viewer): ...@@ -314,14 +241,14 @@ class TransposeInplace(_Op, Viewer):
def grad(self, x, gz): def grad(self, x, gz):
return transpose(gz) return transpose(gz)
def c_impl(self, x, z): def c_code(self, (x, ), (z, ), sub):
return """ return """
PyArrayObject* transposed = (PyArrayObject*)PyArray_Transpose(%(x)s, NULL); PyArrayObject* transposed = (PyArrayObject*)PyArray_Transpose(%(x)s, NULL);
if (%(z)s) { if (%(z)s) {
Py_XDECREF(%(z)s); Py_XDECREF(%(z)s);
} }
%(z)s = transposed; %(z)s = transposed;
""" """ % locals()
transpose_inplace = gof.op.constructor(TransposeInplace) transpose_inplace = gof.op.constructor(TransposeInplace)
def transpose(x, **kwargs): def transpose(x, **kwargs):
return transpose_inplace(tensor_copy(x), **kwargs) return transpose_inplace(tensor_copy(x), **kwargs)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论