subtensor works

上级 fb352f6b
......@@ -129,6 +129,7 @@ class T_subtensor(unittest.TestCase):
self.failUnless(t.owner.__class__ is Subtensor)
try:
tval = eval_outputs([t])
self.fail()
except Exception, e:
if e[0] != 'index out of bounds':
raise
......@@ -146,62 +147,113 @@ class T_subtensor(unittest.TestCase):
tval = eval_outputs([t])
self.failUnless(tval.shape == (2,))
self.failUnless(tval[1] == 5.0)
if 0:
def test1_err_invalid(self):
n = tinit(numpy.ones(1))
try:
t = n[0,0]
self.fail()
except ValueError, e:
self.failUnless(e[0] is Subtensor.e_invalid)
def test1_ok_elem(self):
n = tinit(numpy.ones(1)*5)
t = n[0]
self.failUnless(t.owner.__class__ is Subtensor)
tval = eval_outputs([t])
self.failUnless(tval.shape == (1,))
self.failUnless(tval[0] == 5.0)
def test1_ok_range_infinite(self):
n = tinit(numpy.ones(3)*5)
t = n[1:]
self.failUnless(t.owner.__class__ is Subtensor)
def test1_err_invalid(self):
n = tinit(numpy.ones(1))
try:
t = n[0,0]
self.fail()
except ValueError, e:
self.failUnless(e[0] is Subtensor.e_invalid)
def test1_ok_elem(self):
n = tinit(numpy.ones(1)*5)
t = n[0]
self.failUnless(t.owner.__class__ is Subtensor)
tval = eval_outputs([t])
self.failUnless(tval.shape == ())
self.failUnless(tval == 5.0)
def test1_ok_range_infinite(self):
n = tinit(numpy.ones(3)*5)
t = n[1:]
self.failUnless(t.owner.__class__ is Subtensor)
tval = eval_outputs([t])
self.failUnless(tval.shape == (2,))
self.failUnless(tval[1] == 5.0)
def test1_ok_strided(self):
n = tinit(numpy.ones(5)*5)
t = n[1::2]
self.failUnless(t.owner.__class__ is Subtensor)
tval = eval_outputs([t])
self.failUnless(tval.shape == (2,))
self.failUnless(tval[1] == 5.0)
tval = eval_outputs([n[0:-1:2]]) #0 to 1 from the end stepping by 2
self.failUnless(tval.shape == (2,))
self.failUnless(tval[1] == 5.0)
def test2_err_bounds0(self):
n = tinit(numpy.ones((2,3))*5)
t = n[0,4]
self.failUnless(t.owner.__class__ is Subtensor)
try:
tval = eval_outputs([t])
self.failUnless(tval.shape == (2,))
self.failUnless(tval[1] == 5.0)
def test1_ok_strided(self):
n = tinit(numpy.ones(5)*5)
t = n[1::2]
self.failUnless(t.owner.__class__ is Subtensor)
self.fail()
except IndexError, e:
return
def test2_err_bounds1(self):
n = tinit(numpy.ones((2,3))*5)
t = n[4:5,2]
self.failUnless(t.owner.__class__ is Subtensor)
try:
tval = eval_outputs([t])
self.failUnless(tval.shape == (3,))
self.failUnless(tval[1] == 5.0)
except Exception, e:
if e[0] != 'index out of bounds':
raise
def test2_ok_elem(self):
n = tinit(numpy.asarray(range(6)).reshape((2,3)))
t = n[0,2]
self.failUnless(t.owner.__class__ is Subtensor)
tval = eval_outputs([t])
self.failUnless(tval.shape == ())
self.failUnless(numpy.all(tval == 2))
def test2_ok_row(self):
n = tinit(numpy.asarray(range(6)).reshape((2,3)))
t = n[1]
self.failIf(any(n.broadcastable))
self.failUnless(t.owner.__class__ is Subtensor)
tval = eval_outputs([t])
self.failUnless(tval.shape == (3,))
self.failUnless(numpy.all(tval == [3,4,5]))
tval = eval_outputs([n[1:-1:2]])
self.failUnless(tval.shape == (3,))
self.failUnless(tval[1] == 5.0)
def test2_ok_col(self):
n = tinit(numpy.ones((2,3))*5)
t = n[:,0]
self.failUnless(t.owner.__class__ is Subtensor)
self.failIf(any(n.broadcastable))
tval = eval_outputs([t])
self.failUnless(tval.shape == (2,))
self.failUnless(numpy.all(tval == 5.0))
def test2(self):
raise NotImplementedError() #remember to bring back the rest of tests
if 0:
def test2_err_bounds0(self):
raise NotImplementedError()
def test2_err_bounds1(self):
raise NotImplementedError()
def test2_ok_elem(self):
raise NotImplementedError()
def test2_ok_row(self):
raise NotImplementedError()
def test2_ok_col(self):
raise NotImplementedError()
def test2_ok_rows_finite(self):
raise NotImplementedError()
def test2_ok_cols_infinite(self):
raise NotImplementedError()
def test2_ok_strided(self):
raise NotImplementedError()
def test3_ok_mat(self):
raise NotImplementedError()
def test2_ok_rows_finite(self):
n = tinit(numpy.ones((4,3))*5)
t = n[1:3,0]
self.failUnless(t.owner.__class__ is Subtensor)
tval = eval_outputs([t])
self.failUnless(tval.shape == (2,))
self.failUnless(numpy.all(tval == 5.0))
def test2_ok_cols_infinite(self):
n = tinit(numpy.asarray(range(12)).reshape((4,3)))
t = n[1,2:]
self.failUnless(t.owner.__class__ is Subtensor)
tval = eval_outputs([t])
self.failUnless(tval.shape == (1,))
self.failUnless(numpy.all(tval == 5))
def test2_ok_strided(self):
n = tinit(numpy.asarray(range(20)).reshape((4,5)))
t = n[1:4:2,1:5:2]
self.failUnless(t.owner.__class__ is Subtensor)
tval = eval_outputs([t])
self.failUnless(tval.shape == (2,2))
self.failUnless(numpy.all(tval == [[6, 8],[16, 18]]))
def test3_ok_mat(self):
n = tinit(numpy.asarray(range(24)).reshape((2,3,4)))
t = n[0,0,0]
self.failUnless(t.owner.__class__ is Subtensor)
tval = eval_outputs([t])
self.failUnless(tval.shape == ())
self.failUnless(numpy.all(tval == 0))
class T_add(unittest.TestCase):
......
"""A ResultBase to store numpy.ndarray with basic accompanying Ops"""
import sys # for sys.maxint
import inspect
import numpy
from copy import copy
import inspect
from gof import ResultBase, Op, utils, Destroyer, Viewer, AbstractFunctionError
import gof.result
......@@ -374,6 +375,7 @@ class Subtensor(Op, Viewer):
nin = 2
nout = 1
e_invalid = 'invalid index'
debug = 0
def __init__(self, *args,**kwargs):
def as_tuple_result(obj):
if isinstance(obj, ResultBase):
......@@ -384,17 +386,30 @@ class Subtensor(Op, Viewer):
else:
r.data = (obj,)
return r
print 'Subtensor.__init__', args, kwargs
def pad(tplR, N):
l = list(tplR.data)
for i in range(len(l), N):
l.append(slice(0,sys.maxint,1))
tplR.data = tuple(l)
if Subtensor.debug:
print 'Subtensor.__init__', args, kwargs
#Olivier says not to call this
#Op.__init__(self, *args,**kwargs)
#Viewer.__init__(self, *args,**kwargs)
t, coord = args
t = _as_tensor(t)
coord = as_tuple_result(coord)
if len(coord.data) != len(t.broadcastable):
if len(coord.data) > len(t.broadcastable):
raise ValueError(Subtensor.e_invalid)
# add the implicit extra unbounded slices
# e.g. n[0] on a 3d tensor pads to n[0,:,:]
pad(coord, len(t.broadcastable))
broadcastable = [0 for c in coord.data if isinstance(c, slice)]
if Subtensor.debug:
print 'brdcstble', broadcastable
print 't', t.data
print 'coord', coord.data
self.inputs = [t, coord]
self.outputs = [Tensor(t.dtype, broadcastable)]
def view_map(self):
......@@ -402,6 +417,9 @@ class Subtensor(Op, Viewer):
def perform(self):
x = self.inputs[0].data
c = self.inputs[1].data
if Subtensor.debug:
print 'perform: x', x
print 'perform: c', c
if len(c) == 1:
self.outputs[0].data = x.__getitem__(c[0])
else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论