提交 3fbb330a authored 作者: James Bergstra's avatar James Bergstra

merged

......@@ -759,6 +759,16 @@ class T_subtensor(unittest.TestCase):
raise
return
self.fail()
def test1_err_subslice(self):
n = as_tensor(numpy.ones(3))
try:
t = n[slice(0,slice(1,2,None),None)]
except Exception, e:
if e[0] != Subtensor.e_indextype:
raise
return
self.fail()
def test1_ok_range_finite(self):
n = as_tensor(numpy.ones(3)*5)
t = n[0:2]
......@@ -923,6 +933,16 @@ class T_Stack(unittest.TestCase):
c = numpy.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
self.failUnless((eval_outputs([s]) == c).all())
def test_vstack_grad(self):
a = as_tensor(numpy.array([[1, 2, 3], [4, 5, 6]]))
b = as_tensor(numpy.array([[7, 8, 9]]))
s = vertical_stack(a, b)
ga,gb = grad(sum(vertical_stack(a,b)), [a,b])
gval = eval_outputs([ga, gb])
self.failUnless(numpy.all(gval[0] == 1.0))
self.failUnless(numpy.all(gval[1] == 1.0))
class _test_comparison(unittest.TestCase):
def test_gt(self):
......@@ -1679,7 +1699,10 @@ class _test_grad(unittest.TestCase):
if __name__ == '__main__':
unittest.main()
# suite = unittest.TestLoader()
# suite = suite.loadTestsFromTestCase(T_Cast)
# unittest.TextTestRunner(verbosity=2).run(suite)
if 0:
unittest.main()
else:
suite = unittest.TestLoader()
#suite = suite.loadTestsFromTestCase(T_subtensor)
suite = suite.loadTestsFromTestCase(T_Stack)
unittest.TextTestRunner(verbosity=2).run(suite)
......@@ -431,11 +431,11 @@ tensor_from_scalar = TensorFromScalar()
class ScalarFromTensor(Op):
def make_node(self, t):
assert isinstance(t.type, scal.Tensor)
assert isinstance(t.type, Tensor)
assert t.type.broadcastable == ()
return Apply(self,
[s],
[scal.Scalar(dtype = s.type.dtype).make_result()])
[t],
[scal.Scalar(dtype = t.type.dtype).make_result()])
def perform(self, node, (s, ), (out, )):
out[0] = s.flatten()[0]
def grad(self, (s,), (dt,)):
......@@ -679,6 +679,8 @@ class Subtensor(Op):
@todo: add support for advanced tensor indexing (in Subtensor_dx too).
"""
e_invalid = 'The index list is longer than the number of dimensions of the tensor.'
e_subslice = 'nested slicing is not supported'
e_indextype = "Invalid index type or slice for Subtensor"
debug = 0
view_map = {0: [0]}
......@@ -698,27 +700,38 @@ class Subtensor(Op):
return ret
def __init__(self, idx_list):
def convert(entry):
if isinstance(entry, gof.Result) and entry.type == scal.int64:
def convert(entry, slice_ok=True):
scal_types =[scal.int64, scal.int32, scal.int16, scal.int8]
tensor_types = [bscalar, iscalar, lscalar]
if isinstance(entry, gof.Result) and entry.type in scal_types:
return entry.type
elif isinstance(entry, gof.Type) and entry == scal.int64:
elif isinstance(entry, gof.Type) and entry in scal_types:
return entry
elif isinstance(entry, slice):
if isinstance(entry, gof.Result) and entry.type in tensor_types:
return scal.Scalar(entry.type.dtype)
elif isinstance(entry, gof.Type) and entry in tensor_types:
return scal.Scalar(entry.dtype)
elif slice_ok and isinstance(entry, slice):
a = entry.start
b = entry.stop
c = entry.step
return slice(convert(a) if a is not None else None,
convert(b) if b is not None else None,
convert(c) if c is not None else None)
return slice(convert(a, False) if a is not None else None,
convert(b, False) if b is not None else None,
convert(c, False) if c is not None else None)
elif isinstance(entry, int):
return entry
else:
raise TypeError("Invalid index type or slice for Subtensor", entry)
raise TypeError(Subtensor.e_indextype, entry)
self.idx_list = map(convert, idx_list)
def make_node(self, x, *inputs):
x = as_tensor(x)
inputs = tuple(map(scal.as_scalar, inputs))
def my_as_scalar(a):
if isinstance(a, gof.Result) and isinstance(a.type, Tensor):
return scalar_from_tensor(a)
else:
return scal.as_scalar(a)
inputs = tuple(my_as_scalar(a) for a in inputs)
idx_list = list(self.idx_list)
if len(idx_list) > x.type.ndim:
......
......@@ -98,7 +98,7 @@ class NumpyGenerator(gof.op.Op):
and self.ndim == other.ndim \
and self.fn == other.fn
def __hash__(self):
return self.seed ^ self.ndim ^ id(self.fn)
return self.seed ^ self.ndim ^ hash(self.fn)
def make_node(self, _shape):
#TODO: check for constant shape, and guess the broadcastable bits
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论