提交 950015e0 authored 作者: James Bergstra's avatar James Bergstra

added subtensor_dx

上级 6bb0be5c
...@@ -688,6 +688,8 @@ class T_transpose(unittest.TestCase): ...@@ -688,6 +688,8 @@ class T_transpose(unittest.TestCase):
class T_subtensor(unittest.TestCase): class T_subtensor(unittest.TestCase):
def setUp(self): def setUp(self):
Subtensor.debug = False Subtensor.debug = False
numpy.random.seed(12353123)
def test0_err_invalid(self): def test0_err_invalid(self):
#it is impossible to retrieve a view of a 0-d tensor #it is impossible to retrieve a view of a 0-d tensor
n = astensor(numpy.ones(())) n = astensor(numpy.ones(()))
...@@ -833,6 +835,31 @@ class T_subtensor(unittest.TestCase): ...@@ -833,6 +835,31 @@ class T_subtensor(unittest.TestCase):
self.failUnless(numpy.all(tval == 0)) self.failUnless(numpy.all(tval == 0))
def test_grad_1d(self):
    """Gradient through a 1-d subtensor view.

    For t = n[0:, 0], d/dn sum(exp(t)) is exp(n[:, 0]) in column 0 and
    zero everywhere else.
    """
    nval = numpy.random.rand(2, 3)
    n = astensor(nval)
    z = scal.constant(0)
    t = n[z:, z]
    gn = gradient.grad(sum(exp(t)), n)
    gval = eval_outputs([gn])
    # Compare numerically instead of against repr() strings: repr output
    # is brittle w.r.t. numpy's print formatting and the global RNG seed.
    self.failUnless(numpy.allclose(gval[:, 0], numpy.exp(nval[:, 0])),
                    (gval, nval))
    self.failUnless(numpy.all(gval[:, 1:] == 0), gval)
def test_grad_0d(self):
    """Gradient through a 0-d subtensor (single-element read).

    For t = n[1, 0], d/dn sum(exp(t)) is exp(n[1, 0]) at position (1, 0)
    and zero everywhere else.
    """
    nval = numpy.random.rand(2, 3)
    n = astensor(nval)
    t = n[1, 0]
    gn = gradient.grad(sum(exp(t)), n)
    gval = eval_outputs([gn])
    # Build the analytic expected gradient rather than comparing repr()
    # strings, which depend on numpy print formatting and the RNG seed.
    expected = numpy.zeros((2, 3))
    expected[1, 0] = numpy.exp(nval[1, 0])
    self.failUnless(numpy.allclose(gval, expected), (gval, expected))
class T_Stack(unittest.TestCase): class T_Stack(unittest.TestCase):
def test_hstack(self): def test_hstack(self):
a = astensor(numpy.array([[1, 2, 3], [4, 5, 6]]), broadcastable=[False,False]) a = astensor(numpy.array([[1, 2, 3], [4, 5, 6]]), broadcastable=[False,False])
...@@ -1448,3 +1475,4 @@ class T_stdlib(unittest.TestCase): ...@@ -1448,3 +1475,4 @@ class T_stdlib(unittest.TestCase):
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -581,7 +581,48 @@ def transpose(x, **kwargs): ...@@ -581,7 +581,48 @@ def transpose(x, **kwargs):
class Subtensor_dx(Op, Viewer): class Subtensor_dx(Op, Viewer):
"""Return a tensor full of zeros, except for what was sliced from x by """Return a tensor full of zeros, except for what was sliced from x by
Subtensor. Subtensor.
@todo: pass the shape of x, rather than x itself.
@todo: add support for advanced tensor indexing (breaks current perform
implementation).
""" """
def __init__(self, inputs, idx_list, **kwargs):
    """Set up the op.

    `inputs` is the tensor x followed by the scalar index results and
    finally gz; `idx_list` uses Subtensor's index encoding.  The single
    output has x's dtype and broadcastable pattern (same shape as x).
    """
    Op.__init__(self, **kwargs)
    self.idx_list = idx_list
    self.inputs = inputs
    x = inputs[0]
    self.outputs = [Tensor(x.dtype, x.broadcastable)]
def perform(self):
    """Compute the gradient wrt x: zeros shaped like x, with gz scattered
    into the region that the forward Subtensor read.

    self.inputs is [x, *scalar index results, gz].  Each entry of
    self.idx_list is either a slice whose start/stop/step are indexes
    into self.inputs (or None), or a plain index into self.inputs for a
    scalar subscript, mirroring Subtensor's encoding.
    """
    x = self.inputs[0]
    gz = self.inputs[-1]
    # Dereference idx_list against the runtime .data of the index inputs
    # to rebuild the concrete (slice/int) subscript.
    cdata = []
    for c in self.idx_list:
        if isinstance(c, slice):
            cdata.append(slice(
                None if c.start is None else self.inputs[c.start].data,
                None if c.stop is None else self.inputs[c.stop].data,
                None if c.step is None else self.inputs[c.step].data))
        else:
            d = self.inputs[c].data
            assert 'int' in str(d.dtype)
            cdata.append(d)
    gx = numpy.zeros_like(x.data)
    # Always index with a tuple: a one-element tuple (i,) indexes exactly
    # like i, whereas passing a *list* would trigger numpy's advanced
    # indexing, which this op does not support (see class @todo).
    gx[tuple(cdata)] = gz.data
    self.outputs[0].data = gx
def clone_with_new_inputs(self, *new_inputs):
    """Return a fresh Subtensor_dx over `new_inputs`, reusing this op's
    index specification.  The replacement inputs must match one-for-one.
    """
    assert len(new_inputs) == len(self.inputs)
    return Subtensor_dx(new_inputs, self.idx_list)
class Subtensor(Op, Viewer): class Subtensor(Op, Viewer):
"""Return a subtensor view """Return a subtensor view
...@@ -593,6 +634,7 @@ class Subtensor(Op, Viewer): ...@@ -593,6 +634,7 @@ class Subtensor(Op, Viewer):
of each slice are also integer indexes into the inputs array (or None). The of each slice are also integer indexes into the inputs array (or None). The
inputs array is the tensor x, followed by scalar integer results. inputs array is the tensor x, followed by scalar integer results.
@todo: add support for advanced tensor indexing (in Subtensor_dx too).
""" """
e_invalid = 'invalid index' e_invalid = 'invalid index'
debug = 0 debug = 0
...@@ -683,7 +725,7 @@ class Subtensor(Op, Viewer): ...@@ -683,7 +725,7 @@ class Subtensor(Op, Viewer):
assert 'int' in str(d.dtype) assert 'int' in str(d.dtype)
cdata.append(d) cdata.append(d)
if len(cdata) > 1: if len(cdata) > 1:
cdata = tuple(cdata) #there's a diff between tuples and lists here... cdata = tuple(cdata) #there's a diff between tuple and list here...
else: else:
cdata = cdata[0] cdata = cdata[0]
...@@ -692,13 +734,8 @@ class Subtensor(Op, Viewer): ...@@ -692,13 +734,8 @@ class Subtensor(Op, Viewer):
print self.inputs[0].data, cdata, self.outputs[0].data print self.inputs[0].data, cdata, self.outputs[0].data
def grad(self, inputs, (gz,)): def grad(self, inputs, (gz,)):
# - option: allocate a potentially large matrix of zeros, and fill in return [Subtensor_dx(self.inputs + [gz], self.idx_list).outputs[0]]\
# the appropriate elements from gz + [None] * (len(inputs)-1)
# - option: return a sparse matrix
# - option: return gz, but think about how to include a special addition
# function that works on a corresponding view of the original data
# - return a Subtensor_dx op, which we will optimize away.
return [Subtensor_dx(gz, inputs[0], *self.new_args)] + [None] * (len(inputs)-1)
def clone_with_new_inputs(self, *new_inputs): def clone_with_new_inputs(self, *new_inputs):
assert len(self.inputs) == len(new_inputs) assert len(self.inputs) == len(new_inputs)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论