提交 3fbb330a authored 作者: James Bergstra's avatar James Bergstra

merged

...@@ -759,6 +759,16 @@ class T_subtensor(unittest.TestCase): ...@@ -759,6 +759,16 @@ class T_subtensor(unittest.TestCase):
raise raise
return return
self.fail() self.fail()
def test1_err_subslice(self):
    """Indexing with a nested slice (a slice whose stop is itself a
    slice) must raise the Subtensor invalid-index error."""
    n = as_tensor(numpy.ones(3))
    try:
        # The outer slice's stop is itself a slice -> not a valid index.
        t = n[slice(0,slice(1,2,None),None)]
    except Exception, e:
        # Python 2: e[0] is shorthand for e.args[0]. Re-raise anything
        # other than the expected Subtensor error message.
        if e[0] != Subtensor.e_indextype:
            raise
        return
    # Reaching here means no exception was raised: the test fails.
    self.fail()
def test1_ok_range_finite(self): def test1_ok_range_finite(self):
n = as_tensor(numpy.ones(3)*5) n = as_tensor(numpy.ones(3)*5)
t = n[0:2] t = n[0:2]
...@@ -923,6 +933,16 @@ class T_Stack(unittest.TestCase): ...@@ -923,6 +933,16 @@ class T_Stack(unittest.TestCase):
c = numpy.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) c = numpy.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
self.failUnless((eval_outputs([s]) == c).all()) self.failUnless((eval_outputs([s]) == c).all())
def test_vstack_grad(self):
    """The gradient of sum(vertical_stack(a, b)) w.r.t. each input is
    an all-ones array of that input's shape."""
    a = as_tensor(numpy.array([[1, 2, 3], [4, 5, 6]]))
    b = as_tensor(numpy.array([[7, 8, 9]]))
    # d(sum(vstack(a, b)))/da = ones_like(a), and likewise for b.
    # (The original bound vertical_stack(a, b) to an unused local 's';
    # that dead statement is removed here.)
    ga, gb = grad(sum(vertical_stack(a, b)), [a, b])
    gval = eval_outputs([ga, gb])
    self.failUnless(numpy.all(gval[0] == 1.0))
    self.failUnless(numpy.all(gval[1] == 1.0))
class _test_comparison(unittest.TestCase): class _test_comparison(unittest.TestCase):
def test_gt(self): def test_gt(self):
...@@ -1679,7 +1699,10 @@ class _test_grad(unittest.TestCase): ...@@ -1679,7 +1699,10 @@ class _test_grad(unittest.TestCase):
if __name__ == '__main__': if __name__ == '__main__':
if 0:
unittest.main() unittest.main()
# suite = unittest.TestLoader() else:
# suite = suite.loadTestsFromTestCase(T_Cast) suite = unittest.TestLoader()
# unittest.TextTestRunner(verbosity=2).run(suite) #suite = suite.loadTestsFromTestCase(T_subtensor)
suite = suite.loadTestsFromTestCase(T_Stack)
unittest.TextTestRunner(verbosity=2).run(suite)
...@@ -431,11 +431,11 @@ tensor_from_scalar = TensorFromScalar() ...@@ -431,11 +431,11 @@ tensor_from_scalar = TensorFromScalar()
class ScalarFromTensor(Op): class ScalarFromTensor(Op):
def make_node(self, t): def make_node(self, t):
assert isinstance(t.type, scal.Tensor) assert isinstance(t.type, Tensor)
assert t.type.broadcastable == () assert t.type.broadcastable == ()
return Apply(self, return Apply(self,
[s], [t],
[scal.Scalar(dtype = s.type.dtype).make_result()]) [scal.Scalar(dtype = t.type.dtype).make_result()])
def perform(self, node, (s, ), (out, )): def perform(self, node, (s, ), (out, )):
out[0] = s.flatten()[0] out[0] = s.flatten()[0]
def grad(self, (s,), (dt,)): def grad(self, (s,), (dt,)):
...@@ -679,6 +679,8 @@ class Subtensor(Op): ...@@ -679,6 +679,8 @@ class Subtensor(Op):
@todo: add support for advanced tensor indexing (in Subtensor_dx too). @todo: add support for advanced tensor indexing (in Subtensor_dx too).
""" """
e_invalid = 'The index list is longer than the number of dimensions of the tensor.' e_invalid = 'The index list is longer than the number of dimensions of the tensor.'
e_subslice = 'nested slicing is not supported'
e_indextype = "Invalid index type or slice for Subtensor"
debug = 0 debug = 0
view_map = {0: [0]} view_map = {0: [0]}
...@@ -698,27 +700,38 @@ class Subtensor(Op): ...@@ -698,27 +700,38 @@ class Subtensor(Op):
return ret return ret
def __init__(self, idx_list): def __init__(self, idx_list):
def convert(entry): def convert(entry, slice_ok=True):
if isinstance(entry, gof.Result) and entry.type == scal.int64: scal_types =[scal.int64, scal.int32, scal.int16, scal.int8]
tensor_types = [bscalar, iscalar, lscalar]
if isinstance(entry, gof.Result) and entry.type in scal_types:
return entry.type return entry.type
elif isinstance(entry, gof.Type) and entry == scal.int64: elif isinstance(entry, gof.Type) and entry in scal_types:
return entry return entry
elif isinstance(entry, slice): if isinstance(entry, gof.Result) and entry.type in tensor_types:
return scal.Scalar(entry.type.dtype)
elif isinstance(entry, gof.Type) and entry in tensor_types:
return scal.Scalar(entry.dtype)
elif slice_ok and isinstance(entry, slice):
a = entry.start a = entry.start
b = entry.stop b = entry.stop
c = entry.step c = entry.step
return slice(convert(a) if a is not None else None, return slice(convert(a, False) if a is not None else None,
convert(b) if b is not None else None, convert(b, False) if b is not None else None,
convert(c) if c is not None else None) convert(c, False) if c is not None else None)
elif isinstance(entry, int): elif isinstance(entry, int):
return entry return entry
else: else:
raise TypeError("Invalid index type or slice for Subtensor", entry) raise TypeError(Subtensor.e_indextype, entry)
self.idx_list = map(convert, idx_list) self.idx_list = map(convert, idx_list)
def make_node(self, x, *inputs): def make_node(self, x, *inputs):
x = as_tensor(x) x = as_tensor(x)
inputs = tuple(map(scal.as_scalar, inputs)) def my_as_scalar(a):
if isinstance(a, gof.Result) and isinstance(a.type, Tensor):
return scalar_from_tensor(a)
else:
return scal.as_scalar(a)
inputs = tuple(my_as_scalar(a) for a in inputs)
idx_list = list(self.idx_list) idx_list = list(self.idx_list)
if len(idx_list) > x.type.ndim: if len(idx_list) > x.type.ndim:
......
...@@ -98,7 +98,7 @@ class NumpyGenerator(gof.op.Op): ...@@ -98,7 +98,7 @@ class NumpyGenerator(gof.op.Op):
and self.ndim == other.ndim \ and self.ndim == other.ndim \
and self.fn == other.fn and self.fn == other.fn
def __hash__(self): def __hash__(self):
return self.seed ^ self.ndim ^ id(self.fn) return self.seed ^ self.ndim ^ hash(self.fn)
def make_node(self, _shape): def make_node(self, _shape):
#TODO: check for constant shape, and guess the broadcastable bits #TODO: check for constant shape, and guess the broadcastable bits
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论