subtensor works

上级 fb352f6b
...@@ -129,6 +129,7 @@ class T_subtensor(unittest.TestCase): ...@@ -129,6 +129,7 @@ class T_subtensor(unittest.TestCase):
self.failUnless(t.owner.__class__ is Subtensor) self.failUnless(t.owner.__class__ is Subtensor)
try: try:
tval = eval_outputs([t]) tval = eval_outputs([t])
self.fail()
except Exception, e: except Exception, e:
if e[0] != 'index out of bounds': if e[0] != 'index out of bounds':
raise raise
...@@ -146,62 +147,113 @@ class T_subtensor(unittest.TestCase): ...@@ -146,62 +147,113 @@ class T_subtensor(unittest.TestCase):
tval = eval_outputs([t]) tval = eval_outputs([t])
self.failUnless(tval.shape == (2,)) self.failUnless(tval.shape == (2,))
self.failUnless(tval[1] == 5.0) self.failUnless(tval[1] == 5.0)
if 0: def test1_err_invalid(self):
def test1_err_invalid(self): n = tinit(numpy.ones(1))
n = tinit(numpy.ones(1)) try:
try: t = n[0,0]
t = n[0,0] self.fail()
self.fail() except ValueError, e:
except ValueError, e: self.failUnless(e[0] is Subtensor.e_invalid)
self.failUnless(e[0] is Subtensor.e_invalid) def test1_ok_elem(self):
def test1_ok_elem(self): n = tinit(numpy.ones(1)*5)
n = tinit(numpy.ones(1)*5) t = n[0]
t = n[0] self.failUnless(t.owner.__class__ is Subtensor)
self.failUnless(t.owner.__class__ is Subtensor) tval = eval_outputs([t])
tval = eval_outputs([t]) self.failUnless(tval.shape == ())
self.failUnless(tval.shape == (1,)) self.failUnless(tval == 5.0)
self.failUnless(tval[0] == 5.0) def test1_ok_range_infinite(self):
def test1_ok_range_infinite(self): n = tinit(numpy.ones(3)*5)
n = tinit(numpy.ones(3)*5) t = n[1:]
t = n[1:] self.failUnless(t.owner.__class__ is Subtensor)
self.failUnless(t.owner.__class__ is Subtensor) tval = eval_outputs([t])
self.failUnless(tval.shape == (2,))
self.failUnless(tval[1] == 5.0)
def test1_ok_strided(self):
n = tinit(numpy.ones(5)*5)
t = n[1::2]
self.failUnless(t.owner.__class__ is Subtensor)
tval = eval_outputs([t])
self.failUnless(tval.shape == (2,))
self.failUnless(tval[1] == 5.0)
tval = eval_outputs([n[0:-1:2]]) #0 to 1 from the end stepping by 2
self.failUnless(tval.shape == (2,))
self.failUnless(tval[1] == 5.0)
def test2_err_bounds0(self):
n = tinit(numpy.ones((2,3))*5)
t = n[0,4]
self.failUnless(t.owner.__class__ is Subtensor)
try:
tval = eval_outputs([t]) tval = eval_outputs([t])
self.failUnless(tval.shape == (2,)) self.fail()
self.failUnless(tval[1] == 5.0) except IndexError, e:
def test1_ok_strided(self): return
n = tinit(numpy.ones(5)*5) def test2_err_bounds1(self):
t = n[1::2] n = tinit(numpy.ones((2,3))*5)
self.failUnless(t.owner.__class__ is Subtensor) t = n[4:5,2]
self.failUnless(t.owner.__class__ is Subtensor)
try:
tval = eval_outputs([t]) tval = eval_outputs([t])
self.failUnless(tval.shape == (3,)) except Exception, e:
self.failUnless(tval[1] == 5.0) if e[0] != 'index out of bounds':
raise
def test2_ok_elem(self):
n = tinit(numpy.asarray(range(6)).reshape((2,3)))
t = n[0,2]
self.failUnless(t.owner.__class__ is Subtensor)
tval = eval_outputs([t])
self.failUnless(tval.shape == ())
self.failUnless(numpy.all(tval == 2))
def test2_ok_row(self):
n = tinit(numpy.asarray(range(6)).reshape((2,3)))
t = n[1]
self.failIf(any(n.broadcastable))
self.failUnless(t.owner.__class__ is Subtensor)
tval = eval_outputs([t])
self.failUnless(tval.shape == (3,))
self.failUnless(numpy.all(tval == [3,4,5]))
tval = eval_outputs([n[1:-1:2]]) def test2_ok_col(self):
self.failUnless(tval.shape == (3,)) n = tinit(numpy.ones((2,3))*5)
self.failUnless(tval[1] == 5.0) t = n[:,0]
self.failUnless(t.owner.__class__ is Subtensor)
self.failIf(any(n.broadcastable))
tval = eval_outputs([t])
self.failUnless(tval.shape == (2,))
self.failUnless(numpy.all(tval == 5.0))
def test2(self): def test2_ok_rows_finite(self):
raise NotImplementedError() #remember to bring back the rest of tests n = tinit(numpy.ones((4,3))*5)
if 0: t = n[1:3,0]
def test2_err_bounds0(self): self.failUnless(t.owner.__class__ is Subtensor)
raise NotImplementedError() tval = eval_outputs([t])
def test2_err_bounds1(self): self.failUnless(tval.shape == (2,))
raise NotImplementedError() self.failUnless(numpy.all(tval == 5.0))
def test2_ok_elem(self):
raise NotImplementedError() def test2_ok_cols_infinite(self):
def test2_ok_row(self): n = tinit(numpy.asarray(range(12)).reshape((4,3)))
raise NotImplementedError() t = n[1,2:]
def test2_ok_col(self): self.failUnless(t.owner.__class__ is Subtensor)
raise NotImplementedError() tval = eval_outputs([t])
def test2_ok_rows_finite(self): self.failUnless(tval.shape == (1,))
raise NotImplementedError() self.failUnless(numpy.all(tval == 5))
def test2_ok_cols_infinite(self):
raise NotImplementedError() def test2_ok_strided(self):
def test2_ok_strided(self): n = tinit(numpy.asarray(range(20)).reshape((4,5)))
raise NotImplementedError() t = n[1:4:2,1:5:2]
self.failUnless(t.owner.__class__ is Subtensor)
def test3_ok_mat(self): tval = eval_outputs([t])
raise NotImplementedError() self.failUnless(tval.shape == (2,2))
self.failUnless(numpy.all(tval == [[6, 8],[16, 18]]))
def test3_ok_mat(self):
n = tinit(numpy.asarray(range(24)).reshape((2,3,4)))
t = n[0,0,0]
self.failUnless(t.owner.__class__ is Subtensor)
tval = eval_outputs([t])
self.failUnless(tval.shape == ())
self.failUnless(numpy.all(tval == 0))
class T_add(unittest.TestCase): class T_add(unittest.TestCase):
......
"""A ResultBase to store numpy.ndarray with basic accompanying Ops""" """A ResultBase to store numpy.ndarray with basic accompanying Ops"""
import sys # for sys.maxint
import inspect
import numpy import numpy
from copy import copy
import inspect
from gof import ResultBase, Op, utils, Destroyer, Viewer, AbstractFunctionError from gof import ResultBase, Op, utils, Destroyer, Viewer, AbstractFunctionError
import gof.result import gof.result
...@@ -374,6 +375,7 @@ class Subtensor(Op, Viewer): ...@@ -374,6 +375,7 @@ class Subtensor(Op, Viewer):
nin = 2 nin = 2
nout = 1 nout = 1
e_invalid = 'invalid index' e_invalid = 'invalid index'
debug = 0
def __init__(self, *args,**kwargs): def __init__(self, *args,**kwargs):
def as_tuple_result(obj): def as_tuple_result(obj):
if isinstance(obj, ResultBase): if isinstance(obj, ResultBase):
...@@ -384,17 +386,30 @@ class Subtensor(Op, Viewer): ...@@ -384,17 +386,30 @@ class Subtensor(Op, Viewer):
else: else:
r.data = (obj,) r.data = (obj,)
return r return r
def pad(tplR, N):
print 'Subtensor.__init__', args, kwargs l = list(tplR.data)
for i in range(len(l), N):
l.append(slice(0,sys.maxint,1))
tplR.data = tuple(l)
if Subtensor.debug:
print 'Subtensor.__init__', args, kwargs
#Olivier says not to call this #Olivier says not to call this
#Op.__init__(self, *args,**kwargs) #Op.__init__(self, *args,**kwargs)
#Viewer.__init__(self, *args,**kwargs) #Viewer.__init__(self, *args,**kwargs)
t, coord = args t, coord = args
t = _as_tensor(t) t = _as_tensor(t)
coord = as_tuple_result(coord) coord = as_tuple_result(coord)
if len(coord.data) != len(t.broadcastable): if len(coord.data) > len(t.broadcastable):
raise ValueError(Subtensor.e_invalid) raise ValueError(Subtensor.e_invalid)
# add the implicit extra unbounded slices
# e.g. n[0] on a 3d tensor pads to n[0,:,:]
pad(coord, len(t.broadcastable))
broadcastable = [0 for c in coord.data if isinstance(c, slice)] broadcastable = [0 for c in coord.data if isinstance(c, slice)]
if Subtensor.debug:
print 'brdcstble', broadcastable
print 't', t.data
print 'coord', coord.data
self.inputs = [t, coord] self.inputs = [t, coord]
self.outputs = [Tensor(t.dtype, broadcastable)] self.outputs = [Tensor(t.dtype, broadcastable)]
def view_map(self): def view_map(self):
...@@ -402,6 +417,9 @@ class Subtensor(Op, Viewer): ...@@ -402,6 +417,9 @@ class Subtensor(Op, Viewer):
def perform(self): def perform(self):
x = self.inputs[0].data x = self.inputs[0].data
c = self.inputs[1].data c = self.inputs[1].data
if Subtensor.debug:
print 'perform: x', x
print 'perform: c', c
if len(c) == 1: if len(c) == 1:
self.outputs[0].data = x.__getitem__(c[0]) self.outputs[0].data = x.__getitem__(c[0])
else: else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论