subtensor works

上级 fb352f6b
...@@ -129,6 +129,7 @@ class T_subtensor(unittest.TestCase): ...@@ -129,6 +129,7 @@ class T_subtensor(unittest.TestCase):
self.failUnless(t.owner.__class__ is Subtensor) self.failUnless(t.owner.__class__ is Subtensor)
try: try:
tval = eval_outputs([t]) tval = eval_outputs([t])
self.fail()
except Exception, e: except Exception, e:
if e[0] != 'index out of bounds': if e[0] != 'index out of bounds':
raise raise
...@@ -146,7 +147,6 @@ class T_subtensor(unittest.TestCase): ...@@ -146,7 +147,6 @@ class T_subtensor(unittest.TestCase):
tval = eval_outputs([t]) tval = eval_outputs([t])
self.failUnless(tval.shape == (2,)) self.failUnless(tval.shape == (2,))
self.failUnless(tval[1] == 5.0) self.failUnless(tval[1] == 5.0)
if 0:
def test1_err_invalid(self): def test1_err_invalid(self):
n = tinit(numpy.ones(1)) n = tinit(numpy.ones(1))
try: try:
...@@ -159,8 +159,8 @@ class T_subtensor(unittest.TestCase): ...@@ -159,8 +159,8 @@ class T_subtensor(unittest.TestCase):
t = n[0] t = n[0]
self.failUnless(t.owner.__class__ is Subtensor) self.failUnless(t.owner.__class__ is Subtensor)
tval = eval_outputs([t]) tval = eval_outputs([t])
self.failUnless(tval.shape == (1,)) self.failUnless(tval.shape == ())
self.failUnless(tval[0] == 5.0) self.failUnless(tval == 5.0)
def test1_ok_range_infinite(self): def test1_ok_range_infinite(self):
n = tinit(numpy.ones(3)*5) n = tinit(numpy.ones(3)*5)
t = n[1:] t = n[1:]
...@@ -173,35 +173,87 @@ class T_subtensor(unittest.TestCase): ...@@ -173,35 +173,87 @@ class T_subtensor(unittest.TestCase):
t = n[1::2] t = n[1::2]
self.failUnless(t.owner.__class__ is Subtensor) self.failUnless(t.owner.__class__ is Subtensor)
tval = eval_outputs([t]) tval = eval_outputs([t])
self.failUnless(tval.shape == (3,)) self.failUnless(tval.shape == (2,))
self.failUnless(tval[1] == 5.0) self.failUnless(tval[1] == 5.0)
tval = eval_outputs([n[1:-1:2]]) tval = eval_outputs([n[0:-1:2]]) #0 to 1 from the end stepping by 2
self.failUnless(tval.shape == (3,)) self.failUnless(tval.shape == (2,))
self.failUnless(tval[1] == 5.0) self.failUnless(tval[1] == 5.0)
def test2(self):
raise NotImplementedError() #remember to bring back the rest of tests
if 0:
def test2_err_bounds0(self): def test2_err_bounds0(self):
raise NotImplementedError() n = tinit(numpy.ones((2,3))*5)
t = n[0,4]
self.failUnless(t.owner.__class__ is Subtensor)
try:
tval = eval_outputs([t])
self.fail()
except IndexError, e:
return
def test2_err_bounds1(self): def test2_err_bounds1(self):
raise NotImplementedError() n = tinit(numpy.ones((2,3))*5)
t = n[4:5,2]
self.failUnless(t.owner.__class__ is Subtensor)
try:
tval = eval_outputs([t])
except Exception, e:
if e[0] != 'index out of bounds':
raise
def test2_ok_elem(self): def test2_ok_elem(self):
raise NotImplementedError() n = tinit(numpy.asarray(range(6)).reshape((2,3)))
t = n[0,2]
self.failUnless(t.owner.__class__ is Subtensor)
tval = eval_outputs([t])
self.failUnless(tval.shape == ())
self.failUnless(numpy.all(tval == 2))
def test2_ok_row(self): def test2_ok_row(self):
raise NotImplementedError() n = tinit(numpy.asarray(range(6)).reshape((2,3)))
t = n[1]
self.failIf(any(n.broadcastable))
self.failUnless(t.owner.__class__ is Subtensor)
tval = eval_outputs([t])
self.failUnless(tval.shape == (3,))
self.failUnless(numpy.all(tval == [3,4,5]))
def test2_ok_col(self): def test2_ok_col(self):
raise NotImplementedError() n = tinit(numpy.ones((2,3))*5)
t = n[:,0]
self.failUnless(t.owner.__class__ is Subtensor)
self.failIf(any(n.broadcastable))
tval = eval_outputs([t])
self.failUnless(tval.shape == (2,))
self.failUnless(numpy.all(tval == 5.0))
def test2_ok_rows_finite(self): def test2_ok_rows_finite(self):
raise NotImplementedError() n = tinit(numpy.ones((4,3))*5)
t = n[1:3,0]
self.failUnless(t.owner.__class__ is Subtensor)
tval = eval_outputs([t])
self.failUnless(tval.shape == (2,))
self.failUnless(numpy.all(tval == 5.0))
def test2_ok_cols_infinite(self): def test2_ok_cols_infinite(self):
raise NotImplementedError() n = tinit(numpy.asarray(range(12)).reshape((4,3)))
t = n[1,2:]
self.failUnless(t.owner.__class__ is Subtensor)
tval = eval_outputs([t])
self.failUnless(tval.shape == (1,))
self.failUnless(numpy.all(tval == 5))
def test2_ok_strided(self): def test2_ok_strided(self):
raise NotImplementedError() n = tinit(numpy.asarray(range(20)).reshape((4,5)))
t = n[1:4:2,1:5:2]
self.failUnless(t.owner.__class__ is Subtensor)
tval = eval_outputs([t])
self.failUnless(tval.shape == (2,2))
self.failUnless(numpy.all(tval == [[6, 8],[16, 18]]))
def test3_ok_mat(self): def test3_ok_mat(self):
raise NotImplementedError() n = tinit(numpy.asarray(range(24)).reshape((2,3,4)))
t = n[0,0,0]
self.failUnless(t.owner.__class__ is Subtensor)
tval = eval_outputs([t])
self.failUnless(tval.shape == ())
self.failUnless(numpy.all(tval == 0))
class T_add(unittest.TestCase): class T_add(unittest.TestCase):
......
"""A ResultBase to store numpy.ndarray with basic accompanying Ops""" """A ResultBase to store numpy.ndarray with basic accompanying Ops"""
import sys # for sys.maxint
import inspect
import numpy import numpy
from copy import copy
import inspect
from gof import ResultBase, Op, utils, Destroyer, Viewer, AbstractFunctionError from gof import ResultBase, Op, utils, Destroyer, Viewer, AbstractFunctionError
import gof.result import gof.result
...@@ -374,6 +375,7 @@ class Subtensor(Op, Viewer): ...@@ -374,6 +375,7 @@ class Subtensor(Op, Viewer):
nin = 2 nin = 2
nout = 1 nout = 1
e_invalid = 'invalid index' e_invalid = 'invalid index'
debug = 0
def __init__(self, *args,**kwargs): def __init__(self, *args,**kwargs):
def as_tuple_result(obj): def as_tuple_result(obj):
if isinstance(obj, ResultBase): if isinstance(obj, ResultBase):
...@@ -384,7 +386,13 @@ class Subtensor(Op, Viewer): ...@@ -384,7 +386,13 @@ class Subtensor(Op, Viewer):
else: else:
r.data = (obj,) r.data = (obj,)
return r return r
def pad(tplR, N):
l = list(tplR.data)
for i in range(len(l), N):
l.append(slice(0,sys.maxint,1))
tplR.data = tuple(l)
if Subtensor.debug:
print 'Subtensor.__init__', args, kwargs print 'Subtensor.__init__', args, kwargs
#Olivier says not to call this #Olivier says not to call this
#Op.__init__(self, *args,**kwargs) #Op.__init__(self, *args,**kwargs)
...@@ -392,9 +400,16 @@ class Subtensor(Op, Viewer): ...@@ -392,9 +400,16 @@ class Subtensor(Op, Viewer):
t, coord = args t, coord = args
t = _as_tensor(t) t = _as_tensor(t)
coord = as_tuple_result(coord) coord = as_tuple_result(coord)
if len(coord.data) != len(t.broadcastable): if len(coord.data) > len(t.broadcastable):
raise ValueError(Subtensor.e_invalid) raise ValueError(Subtensor.e_invalid)
# add the implicit extra unbounded slices
# e.g. n[0] on a 3d tensor pads to n[0,:,:]
pad(coord, len(t.broadcastable))
broadcastable = [0 for c in coord.data if isinstance(c, slice)] broadcastable = [0 for c in coord.data if isinstance(c, slice)]
if Subtensor.debug:
print 'brdcstble', broadcastable
print 't', t.data
print 'coord', coord.data
self.inputs = [t, coord] self.inputs = [t, coord]
self.outputs = [Tensor(t.dtype, broadcastable)] self.outputs = [Tensor(t.dtype, broadcastable)]
def view_map(self): def view_map(self):
...@@ -402,6 +417,9 @@ class Subtensor(Op, Viewer): ...@@ -402,6 +417,9 @@ class Subtensor(Op, Viewer):
def perform(self): def perform(self):
x = self.inputs[0].data x = self.inputs[0].data
c = self.inputs[1].data c = self.inputs[1].data
if Subtensor.debug:
print 'perform: x', x
print 'perform: c', c
if len(c) == 1: if len(c) == 1:
self.outputs[0].data = x.__getitem__(c[0]) self.outputs[0].data = x.__getitem__(c[0])
else: else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论