more cleanup in tensor.py

上级 14ab328e
......@@ -398,13 +398,9 @@ class T_abs(unittest.TestCase):
verify_grad(self, Abs, [numpy.ones(())])
verify_grad(self, Abs, [numpy.ones(3)])
class AbsBadGrad(tensor._Elemwise):
def impl(self, x):
return numpy.abs(x)
class AbsBadGrad(Abs):
def grad(self, (x, ), (gz, )):
return mul(gz * sgn(x),0.9),
def c_foreach(self, (x_i, ), (z_i, )):
return "z_i = abs(x_i);"
def test_badgrad(self):
try:
......
"""A simple class to store ndarray data """
from gof import ResultBase, Op, utils
from gof import ResultBase, Op, utils, AbstractFunctionError
import numpy
from copy import copy
......
......@@ -9,7 +9,6 @@ import gof.result
import gof.op
from base_tensor import BaseTensor, BaseTensorOp
from elemwise import Elemwise
import blas # for gemm, dot
import elemwise2 as s2t
......@@ -159,78 +158,6 @@ class _Binary:
nin = 2
class _Elemwise(Elemwise, _Op):
@staticmethod
def extract_name(name):
if name.endswith("_i"):
return name[:-2]
else:
return name
@staticmethod
def is_loop_var(name):
return name.endswith("_i")
def var_desc(self):
cls = self.__class__
(self, inames, onames), _1, _2, _3 = inspect.getargspec(cls.c_foreach)
return ([(cls.extract_name(name), cls.is_loop_var(name)) for name in inames],
[(cls.extract_name(name), cls.is_loop_var(name)) for name in onames])
def propagate_broadcastable(self, *inputs):
idesc, odesc = self.var_desc()
nonloop_o = [o[0] for o in odesc if not o[1]]
if nonloop_o:
raise Exception("Cannot infer broadcastable for non-loop variable(s) %s" % nonloop_o)
all_bcast = [broadcastable for broadcastable, i in zip(inputs, idesc) if i[1]]
if reduce(lambda x, y: x is not False and x == y and y, [len(x) for x in all_bcast]) is False:
raise TypeError(_Elemwise.propagate_broadcastable.E_ndim, self.__class__)
ret = []
for arr in zip(*all_bcast):
if 0 in arr:
ret.append(0)
else:
ret.append(1)
return [ret] * self.nout
propagate_broadcastable.E_ndim \
= "Inputs that are loop variables do not all have the same number of dimensions."
def c_init(self, inputs, outputs):
raise AbstractFunctionError()
def c_foreach(self, inputs, outputs):
raise AbstractFunctionError()
def c_finalize(self, inputs, outputs):
raise AbstractFunctionError()
def c_code_init(self):
return self.c_init(self.inputs, self.outputs)
def c_code_foreach(self):
return self.c_foreach(self.inputs, self.outputs)
def c_code_finalize(self):
return self.c_finalize(self.inputs, self.outputs)
class TensorScalarOp(_Elemwise):
def var_desc(self):
return [('x', 1), ('a', 0)], [('z', 1)]
def c_code_init(self):
return """
dtype_%(a)s _%(a)s;
if (PyArray_SIZE(%(a)s) != 1) {
PyErr_SetString(PyExc_ValueError, \"The size of the scalar argument is not 1.\");
%(fail)s
}
_%(a)s = ((dtype_%(a)s*)PyArray_DATA(%(a)s))[0];
"""
def c_code_foreach(self):
return "%%(z)s_i = %s;" % self.c_expr
##########################
# Unary Operations
##########################
......@@ -314,14 +241,14 @@ class TransposeInplace(_Op, Viewer):
def grad(self, x, gz):
return transpose(gz)
def c_impl(self, x, z):
def c_code(self, (x, ), (z, ), sub):
return """
PyArrayObject* transposed = (PyArrayObject*)PyArray_Transpose(%(x)s, NULL);
if (%(z)s) {
Py_XDECREF(%(z)s);
}
%(z)s = transposed;
"""
""" % locals()
transpose_inplace = gof.op.constructor(TransposeInplace)
def transpose(x, **kwargs):
return transpose_inplace(tensor_copy(x), **kwargs)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论