bringing back subtensor, fixing add in c

上级 6aa77540
......@@ -4,6 +4,7 @@ import numpy
from copy import copy
import inspect
from gof import ResultBase, Op, utils, Destroyer, Viewer, AbstractFunctionError
import gof.result
from base_tensor import BaseTensor, BaseTensorOp
from elemwise import Elemwise
......@@ -61,8 +62,8 @@ class Tensor(BaseTensor):
T = property(__get_T)
#SLICING
def __getitem__(self, key): raise NotImplementedError()
def __getslice__(self, key): raise NotImplementedError()
def __getitem__(self, item): return subtensor(self, item)
def __getslice__(self, *args): return subtensor(self, slice(*args))
# alternate Tensor constructor
def tinit(data, broadcastable=None, role=None, name=None):
......@@ -113,6 +114,11 @@ def _assert_tensor_scalar(x, a):
if numpy.product(a.shape) != 1:
raise ValueError("The second argument must be a scalar.")
def _as_tensor(obj):
if isinstance(obj, Tensor):
return obj
else:
return tinit(obj)
class _Op(BaseTensorOp):
"""A convenient base for the ops in this file"""
......@@ -121,13 +127,7 @@ class _Op(BaseTensorOp):
@classmethod
def input_wrapper(cls, obj):
if isinstance(obj, Tensor):
return obj
else:
return tinit(obj)
# nin = -1
# nout = 1
return _as_tensor(obj)
# def upcast(dtype, *dtypes):
# z = numpy.zeros((), dtype = dtype)
......@@ -344,46 +344,76 @@ class TensorCopy(_Elemwise):
return "%(z)s_i = %(x)s_i;"
tensor_copy = _constructor(TensorCopy)
if 0:
##########################
# View Operations
##########################
##########################
# View Operations
##########################
class transpose(_Op, Viewer):
def view_map(self):
return {self.out: [self.inputs[0]]}
def impl(self, x):
return x.T
def grad(self, x, gz):
return transpose_copy(gz)
def propagate_broadcastable(self, x):
rval = list(x)
rval.reverse()
return [rval]
def c_impl(self, x, z):
return """
PyArrayObject* transposed = (PyArrayObject*)PyArray_Transpose(%(x)s, NULL);
if (%(z)s) {
Py_XDECREF(%(z)s);
}
%(z)s = transposed;
"""
class Subtensor(_Op, Viewer):
def view_map(self):
return {self.out: [self.inputs[0]]}
def impl(x, item):
rval = x.__getitem__(item)
#print 'get_slice running', rval
return rval
def grad(x, gz):
# - option: allocate a potentially large matrix of zeros, and fill in
# the appropriate elements from gz
# - option: return a sparse matrix
# - option: return gz, but think about how to include a special addition
# function that uses a matching view over the original data
raise NotImplemented
class Transpose(_Op, Viewer):
def view_map(self):
return {self.out: [self.inputs[0]]}
def propagate_broadcastable(self, x):
rval = list(x)
rval.reverse()
return [rval]
def impl(self, x):
return x.T #numpy's transpose
def grad(self, x, gz):
return transpose_copy(gz)
def c_impl(self, x, z):
return """
PyArrayObject* transposed = (PyArrayObject*)PyArray_Transpose(%(x)s, NULL);
if (%(z)s) {
Py_XDECREF(%(z)s);
}
%(z)s = transposed;
"""
transpose = _constructor(Transpose)
class Subtensor(Op, Viewer):
nin = 2
nout = 1
e_invalid = 'invalid index'
def __init__(self, *args,**kwargs):
def as_tuple_result(obj):
if isinstance(obj, ResultBase):
return obj
r = gof.result.PythonResult(None)
if isinstance(obj, tuple):
r.data = obj
else:
r.data = (obj,)
return r
print 'Subtensor.__init__', args, kwargs
#Olivier says not to call this
#Op.__init__(self, *args,**kwargs)
#Viewer.__init__(self, *args,**kwargs)
t, coord = args
t = _as_tensor(t)
coord = as_tuple_result(coord)
if len(coord.data) != len(t.broadcastable):
raise ValueError(Subtensor.e_invalid)
broadcastable = [0 for c in coord.data if isinstance(c, slice)]
self.inputs = [t, coord]
self.outputs = [Tensor(t.dtype, broadcastable)]
def view_map(self):
return {self.out: [self.inputs[0]]}
def perform(self):
x = self.inputs[0].data
c = self.inputs[1].data
if len(c) == 1:
self.outputs[0].data = x.__getitem__(c[0])
else:
self.outputs[0].data = x.__getitem__(c)
def grad(x, gz):
# - option: allocate a potentially large matrix of zeros, and fill in
# the appropriate elements from gz
# - option: return a sparse matrix
# - option: return gz, but think about how to include a special addition
# function that works on a corresponding view of the original data
raise NotImplementedError()
subtensor = _constructor(Subtensor)
##########################
......@@ -398,7 +428,7 @@ class AddElemwise(_Elemwise):
def grad(self, (x, y), gz):
return gz, gz
def c_foreach(self, (x_i, y_i), (z_i, )):
return "z_i = x_i + y_i;"
return "%(z)s_i = %(x)s_i + %(y)s_i;"
add_elemwise = _constructor(AddElemwise)
class AddElemwiseInplace(AddElemwise.inplace_version()):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论