Commit 4b72a709 authored by James Bergstra

rewrote Subtensor

Parent f9be5a48
@@ -686,6 +686,8 @@ class T_transpose(unittest.TestCase):
         verify_grad(self, TransposeInplace, [numpy.ones(3)])
 class T_subtensor(unittest.TestCase):
+    def setUp(self):
+        Subtensor.debug = False
     def test0_err_invalid(self):
         #it is impossible to retrieve a view of a 0-d tensor
         n = astensor(numpy.ones(()))
@@ -736,6 +738,7 @@ class T_subtensor(unittest.TestCase):
         self.failUnless(tval.shape == ())
         self.failUnless(tval == 5.0)
     def test1_ok_range_infinite(self):
+        #Subtensor.debug = True
         n = astensor(numpy.ones(3)*5)
         t = n[1:]
         self.failUnless(t.owner.__class__ is Subtensor)
...
@@ -38,7 +38,6 @@ def dfs(outputs):
     @todo: consider rewriting this function as a generator.
     """
-    raise Exception('this function has not been tested')
     r_set = set()
     r_list = list()
     def seek(r):
@@ -64,7 +63,7 @@ def inputs(o):
     Returns the set of inputs necessary to compute the outputs in o
     such that input.owner is None.
     """
-    return [r in dfs(o) if r.owner is None]
+    return [r for r in dfs(o) if r.owner is None]
 def results_and_orphans(i, o, except_unreachable_input=False):
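The second hunk fixes a genuine syntax error: `[r in dfs(o) if r.owner is None]` is not a valid comprehension, so `inputs` could never have run. For orientation, here is a minimal sketch of how the corrected `inputs` plausibly relates to `dfs`; the body of `dfs` beyond `r_set`/`r_list` is not shown in this diff, so the traversal below is an assumption, not the committed code:

    def dfs(outputs):
        # Collect every result reachable from `outputs`, depth-first.
        # Assumes each result has an `owner` op (or None) whose `inputs`
        # list the results it depends on.
        r_set = set()
        r_list = list()
        def seek(r):
            if r not in r_set:
                r_set.add(r)
                if r.owner is not None:
                    for input in r.owner.inputs:
                        seek(input)
                r_list.append(r)
        for output in outputs:
            seek(output)
        return r_list

    def inputs(o):
        # Graph leaves: results that no op produced.
        return [r for r in dfs(o) if r.owner is None]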
...
@@ -21,7 +21,7 @@ def as_scalar(x, name = None):
         s.data = x
         return s
     if isinstance(x, int):
-        s = Scalar('int32', name = name)
+        s = Scalar('int64', name = name)
         s.data = x
         return s
     if isinstance(x, Scalar):
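The 'int32' to 'int64' change matches the dtype numpy itself assigns to a Python int on 64-bit Unix platforms, so an int wrapped by as_scalar and one absorbed into a numpy array now agree. A quick check (illustrative, not part of the commit; the result is platform-dependent):

    import numpy
    # On 64-bit Linux numpy maps a Python int to a C long, i.e. int64,
    # so Scalar('int64') keeps as_scalar consistent with numpy's choice.
    assert numpy.asarray(5).dtype == numpy.dtype('int64')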
...
@@ -292,8 +292,10 @@ class Tensor(Result):
     T = property(lambda self: transpose(self))
     #SLICING
-    def __getitem__(self, item): return subtensor(self, item)
-    def __getslice__(self, *args): return subtensor(self, slice(*args))
+    def __getitem__(self, args): return Subtensor.from_idxs(self,
+            args).outputs[0]
+    def __getslice__(self, *args): return Subtensor.from_idxs(self,
+            (slice(*args),)).outputs[0]
     #COPYING
     def copy(self): return tensor_copy(self)
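In Python 2, a plain two-bound slice such as n[1:] dispatches to __getslice__ (with the missing bound filled in as sys.maxint), while extended indexing such as n[1:3:2] or n[1:, 0] dispatches to __getitem__; both entry points now funnel into Subtensor.from_idxs and return the op's single output. Roughly (illustrative comments, not part of the commit):

    t = n[1:]      # Python 2 calls n.__getslice__(1, sys.maxint)
                   #   -> Subtensor.from_idxs(n, (slice(1, sys.maxint),)).outputs[0]
    u = n[1:, 0]   # extended indexing calls n.__getitem__((slice(1, None), 0))
                   #   -> Subtensor.from_idxs(n, (slice(1, None), 0)).outputs[0]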
@@ -576,68 +578,132 @@ transpose_inplace = gof.op.constructor(TransposeInplace)
 def transpose(x, **kwargs):
     return transpose_inplace(tensor_copy(x), **kwargs)
+class Subtensor_dx(Op, Viewer):
+    """Return a tensor full of zeros, except for what was sliced from x by
+    Subtensor.
+    """
 class Subtensor(Op, Viewer):
-    nin = 2
-    nout = 1
+    """Return a subtensor view
+
+    This class uses a relatively complex internal representation of the inputs
+    to remember how the input tensor x should be sliced. The instance variable
+    idx_list is a list whose elements are either integers or slices. The
+    integers are indexes into the inputs array, and the start/stop/step members
+    of each slice are also integer indexes into the inputs array (or None). The
+    inputs array is the tensor x, followed by scalar integer results.
+    """
     e_invalid = 'invalid index'
     debug = 0
-    def __init__(self, *args,**kwargs):
-        def as_tuple_result(obj):
-            if isinstance(obj, Result):
-                return obj
-            r = gof.result.PythonResult(None)
-            if isinstance(obj, tuple):
-                r.data = obj
-            else:
-                r.data = (obj,)
-            return r
-        def pad(tplR, N):
-            l = list(tplR.data)
-            for i in range(len(l), N):
-                l.append(slice(0,sys.maxint,1))
-            tplR.data = tuple(l)
-        if Subtensor.debug:
-            print 'Subtensor.__init__', args, kwargs
-        #Olivier says not to call this
-        #Op.__init__(self, *args,**kwargs)
-        #Viewer.__init__(self, *args,**kwargs)
-        t, coord = args
-        t = _as_tensor(t)
-        coord = as_tuple_result(coord)
-        if len(coord.data) > len(t.broadcastable):
-            raise ValueError(Subtensor.e_invalid)
-        # add the implicit extra unbounded slices
-        # e.g. n[0] on a 3d tensor pads to n[0,:,:]
-        pad(coord, len(t.broadcastable))
-        broadcastable = [0 for c in coord.data if isinstance(c, slice)]
-        if Subtensor.debug:
-            print 'brdcstble', broadcastable
-            print 't', t.data
-            print 'coord', coord.data
-        self.inputs = [t, coord]
-        self.outputs = [Tensor(t.dtype, broadcastable)]
+    @staticmethod
+    def from_idxs(x, idxs, **kwargs):
+        if Subtensor.debug:
+            print idxs, sys.maxint
+        def asidx(i):
+            if isinstance(i, int): return scal.constant(i)
+            if isinstance(i, scal.Scalar) and ('int' in i.dtype): return i
+            raise TypeError(Subtensor.e_invalid, i)
+
+        x = _as_tensor(x)
+        idx_list = [] # like args, but with int -> scalar.constant
+        inputs = [x]  # like args, but with slices flattened
+        if not isinstance(idxs, (list, tuple)):
+            idxs = (idxs,)
+
+        for idx in idxs:
+            try:
+                ai = asidx(idx)
+                idx_list.append(len(inputs))
+                inputs.append(ai)
+            except TypeError:
+                if isinstance(idx, slice):
+                    start = None if idx.start is None else asidx(idx.start)
+                    stop = None if idx.stop is None else asidx(idx.stop)
+                    step = None if idx.step is None else asidx(idx.step)
+                    # If we get here, then everything got turned (successfully)
+                    # into a scal.Scalar (with integer dtype) or None
+                    if start:
+                        startpos = len(inputs)
+                        inputs.append(start)
+                    else:
+                        startpos = None
+                    if stop:
+                        stoppos = len(inputs)
+                        inputs.append(stop)
+                    else:
+                        stoppos = None
+                    if step:
+                        steppos = len(inputs)
+                        inputs.append(step)
+                    else:
+                        steppos = None
+                    idx_list.append(slice(startpos, stoppos, steppos))
+                else:
+                    raise
+        assert len(idxs) == len(idx_list)
+        return Subtensor(inputs, idx_list, **kwargs)
+    def __init__(self, inputs, idx_list, **kwargs):
+        if len(idx_list) > len(inputs[0].broadcastable):
+            raise ValueError(Subtensor.e_invalid,
+                    (len(idx_list), len(inputs[0].broadcastable)))
+        #infer the broadcasting pattern
+        padded = list(idx_list) \
+                + [slice(0,sys.maxint,1)] * (len(inputs[0].broadcastable) - len(idx_list))
+        broadcastable = [False for p in padded if isinstance(p, slice)]
+        Op.__init__(self, **kwargs)
+        self.inputs = inputs
+        self.outputs = [Tensor(self.inputs[0].dtype, broadcastable)]
+        self.idx_list = idx_list
     def view_map(self):
         return {self.out: [self.inputs[0]]}
     def perform(self):
         x = self.inputs[0].data
-        c = self.inputs[1].data
-        if Subtensor.debug:
-            print 'perform: x', x
-            print 'perform: c', c
-        if len(c) == 1:
-            self.outputs[0].data = x.__getitem__(c[0])
-        else:
-            self.outputs[0].data = x.__getitem__(c)
-    def grad(self, (x,), (gz,)):
+        cdata = []
+        for c in self.idx_list:
+            if isinstance(c, slice):
+                cdata.append(slice(
+                    None if c.start is None else self.inputs[c.start].data,
+                    None if c.stop is None else self.inputs[c.stop].data,
+                    None if c.step is None else self.inputs[c.step].data))
+            else:
+                d = self.inputs[c].data
+                assert 'int' in str(d.dtype)
+                cdata.append(d)
+        if len(cdata) > 1:
+            cdata = tuple(cdata) #there's a diff between tuples and lists here...
+        else:
+            cdata = cdata[0]
+        self.outputs[0].data = x.__getitem__(cdata)
+        if Subtensor.debug:
+            print self.inputs[0].data, cdata, self.outputs[0].data
+    def grad(self, inputs, (gz,)):
         # - option: allocate a potentially large matrix of zeros, and fill in
         # the appropriate elements from gz
         # - option: return a sparse matrix
         # - option: return gz, but think about how to include a special addition
        # function that works on a corresponding view of the original data
-        raise NotImplementedError()
-subtensor = gof.op.constructor(Subtensor)
+        # - return a Subtensor_dx op, which we will optimize away.
+        return [Subtensor_dx(gz, inputs[0], *self.new_args)] + [None] * (len(inputs)-1)
+    def clone_with_new_inputs(self, *new_inputs):
+        assert len(self.inputs) == len(new_inputs)
+        return Subtensor(new_inputs, self.idx_list)
 class VerticalStack(Op):
     """
...