bringing back subtensor, fixing add in c

上级 6aa77540
...@@ -4,6 +4,7 @@ import numpy ...@@ -4,6 +4,7 @@ import numpy
from copy import copy from copy import copy
import inspect import inspect
from gof import ResultBase, Op, utils, Destroyer, Viewer, AbstractFunctionError from gof import ResultBase, Op, utils, Destroyer, Viewer, AbstractFunctionError
import gof.result
from base_tensor import BaseTensor, BaseTensorOp from base_tensor import BaseTensor, BaseTensorOp
from elemwise import Elemwise from elemwise import Elemwise
...@@ -61,8 +62,8 @@ class Tensor(BaseTensor): ...@@ -61,8 +62,8 @@ class Tensor(BaseTensor):
T = property(__get_T) T = property(__get_T)
#SLICING #SLICING
def __getitem__(self, key): raise NotImplementedError() def __getitem__(self, item): return subtensor(self, item)
def __getslice__(self, key): raise NotImplementedError() def __getslice__(self, *args): return subtensor(self, slice(*args))
# alternate Tensor constructor # alternate Tensor constructor
def tinit(data, broadcastable=None, role=None, name=None): def tinit(data, broadcastable=None, role=None, name=None):
...@@ -113,6 +114,11 @@ def _assert_tensor_scalar(x, a): ...@@ -113,6 +114,11 @@ def _assert_tensor_scalar(x, a):
if numpy.product(a.shape) != 1: if numpy.product(a.shape) != 1:
raise ValueError("The second argument must be a scalar.") raise ValueError("The second argument must be a scalar.")
def _as_tensor(obj):
if isinstance(obj, Tensor):
return obj
else:
return tinit(obj)
class _Op(BaseTensorOp): class _Op(BaseTensorOp):
"""A convenient base for the ops in this file""" """A convenient base for the ops in this file"""
...@@ -121,13 +127,7 @@ class _Op(BaseTensorOp): ...@@ -121,13 +127,7 @@ class _Op(BaseTensorOp):
@classmethod @classmethod
def input_wrapper(cls, obj): def input_wrapper(cls, obj):
if isinstance(obj, Tensor): return _as_tensor(obj)
return obj
else:
return tinit(obj)
# nin = -1
# nout = 1
# def upcast(dtype, *dtypes): # def upcast(dtype, *dtypes):
# z = numpy.zeros((), dtype = dtype) # z = numpy.zeros((), dtype = dtype)
...@@ -344,22 +344,21 @@ class TensorCopy(_Elemwise): ...@@ -344,22 +344,21 @@ class TensorCopy(_Elemwise):
return "%(z)s_i = %(x)s_i;" return "%(z)s_i = %(x)s_i;"
tensor_copy = _constructor(TensorCopy) tensor_copy = _constructor(TensorCopy)
if 0: ##########################
########################## # View Operations
# View Operations ##########################
##########################
class transpose(_Op, Viewer): class Transpose(_Op, Viewer):
def view_map(self): def view_map(self):
return {self.out: [self.inputs[0]]} return {self.out: [self.inputs[0]]}
def impl(self, x):
return x.T
def grad(self, x, gz):
return transpose_copy(gz)
def propagate_broadcastable(self, x): def propagate_broadcastable(self, x):
rval = list(x) rval = list(x)
rval.reverse() rval.reverse()
return [rval] return [rval]
def impl(self, x):
return x.T #numpy's transpose
def grad(self, x, gz):
return transpose_copy(gz)
def c_impl(self, x, z): def c_impl(self, x, z):
return """ return """
...@@ -369,21 +368,52 @@ if 0: ...@@ -369,21 +368,52 @@ if 0:
} }
%(z)s = transposed; %(z)s = transposed;
""" """
transpose = _constructor(Transpose)
class Subtensor(_Op, Viewer): class Subtensor(Op, Viewer):
nin = 2
nout = 1
e_invalid = 'invalid index'
def __init__(self, *args,**kwargs):
def as_tuple_result(obj):
if isinstance(obj, ResultBase):
return obj
r = gof.result.PythonResult(None)
if isinstance(obj, tuple):
r.data = obj
else:
r.data = (obj,)
return r
print 'Subtensor.__init__', args, kwargs
#Olivier says not to call this
#Op.__init__(self, *args,**kwargs)
#Viewer.__init__(self, *args,**kwargs)
t, coord = args
t = _as_tensor(t)
coord = as_tuple_result(coord)
if len(coord.data) != len(t.broadcastable):
raise ValueError(Subtensor.e_invalid)
broadcastable = [0 for c in coord.data if isinstance(c, slice)]
self.inputs = [t, coord]
self.outputs = [Tensor(t.dtype, broadcastable)]
def view_map(self): def view_map(self):
return {self.out: [self.inputs[0]]} return {self.out: [self.inputs[0]]}
def impl(x, item): def perform(self):
rval = x.__getitem__(item) x = self.inputs[0].data
#print 'get_slice running', rval c = self.inputs[1].data
return rval if len(c) == 1:
self.outputs[0].data = x.__getitem__(c[0])
else:
self.outputs[0].data = x.__getitem__(c)
def grad(x, gz): def grad(x, gz):
# - option: allocate a potentially large matrix of zeros, and fill in # - option: allocate a potentially large matrix of zeros, and fill in
# the appropriate elements from gz # the appropriate elements from gz
# - option: return a sparse matrix # - option: return a sparse matrix
# - option: return gz, but think about how to include a special addition # - option: return gz, but think about how to include a special addition
# function that uses a matching view over the original data # function that works on a corresponding view of the original data
raise NotImplemented raise NotImplementedError()
subtensor = _constructor(Subtensor)
########################## ##########################
...@@ -398,7 +428,7 @@ class AddElemwise(_Elemwise): ...@@ -398,7 +428,7 @@ class AddElemwise(_Elemwise):
def grad(self, (x, y), gz): def grad(self, (x, y), gz):
return gz, gz return gz, gz
def c_foreach(self, (x_i, y_i), (z_i, )): def c_foreach(self, (x_i, y_i), (z_i, )):
return "z_i = x_i + y_i;" return "%(z)s_i = %(x)s_i + %(y)s_i;"
add_elemwise = _constructor(AddElemwise) add_elemwise = _constructor(AddElemwise)
class AddElemwiseInplace(AddElemwise.inplace_version()): class AddElemwiseInplace(AddElemwise.inplace_version()):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论