提交 01cda102 authored 作者: James Bergstra's avatar James Bergstra

fixed get_vector_length bug, added support for unpacking vector results

上级 95c4324b
...@@ -105,6 +105,12 @@ def as_ndarray_result(x, name = None, ndim=None): ...@@ -105,6 +105,12 @@ def as_ndarray_result(x, name = None, ndim=None):
return shape_padleft(x, n_ones=(ndim - x.type.ndim)) return shape_padleft(x, n_ones=(ndim - x.type.ndim))
else: else:
return x return x
if isinstance(x, (tuple, list)) and any(isinstance(xi, Result) for xi in x):
try:
return stack(*x)
except (TypeError, ValueError):
pass
try: try:
return constant(x, name=name, ndim=ndim) return constant(x, name=name, ndim=ndim)
except TypeError: except TypeError:
...@@ -597,6 +603,10 @@ class _tensor_py_operators: ...@@ -597,6 +603,10 @@ class _tensor_py_operators:
def copy(self): return tensor_copy(self) def copy(self): return tensor_copy(self)
def __iter__(self): def __iter__(self):
try:
for i in xrange(get_vector_length(self)):
yield self[i]
except:
# This prevents accidental iteration via builtin.sum(self) # This prevents accidental iteration via builtin.sum(self)
raise TypeError('NDArrayType does not support iteration. ' raise TypeError('NDArrayType does not support iteration. '
'Maybe you are using builtin.sum instead of theano.tensor.sum? (Maybe .max?)') 'Maybe you are using builtin.sum instead of theano.tensor.sum? (Maybe .max?)')
...@@ -1690,8 +1700,8 @@ class Join(Op): ...@@ -1690,8 +1700,8 @@ class Join(Op):
if node.ndim != 1: if node.ndim != 1:
raise TypeError('argument must be symbolic vector') raise TypeError('argument must be symbolic vector')
inputs = node.owner.inputs inputs = node.owner.inputs
axis, tensors = inputs[0], inputs[1] axis, tensors = inputs[0], inputs[1:]
# if v is a vector, axis must be 0 # if v is a vector, then axis must be 0
# the question is whether all the inputs are broadcastable. # the question is whether all the inputs are broadcastable.
if all(i.broadcastable[0] for i in tensors): if all(i.broadcastable[0] for i in tensors):
return len(tensors) return len(tensors)
...@@ -1787,8 +1797,7 @@ def get_vector_length(v): ...@@ -1787,8 +1797,7 @@ def get_vector_length(v):
cases. cases.
""" """
if not isinstance(v, gof.Result): v = as_ndarray_result(v)
v = constant(v)
if v.ndim != 1: if v.ndim != 1:
raise TypeError('argument must be symbolic vector') raise TypeError('argument must be symbolic vector')
if isinstance(v, gof.Constant) and v.type.ndim == 1: if isinstance(v, gof.Constant) and v.type.ndim == 1:
...@@ -1923,13 +1932,10 @@ class Reshape(Op): ...@@ -1923,13 +1932,10 @@ class Reshape(Op):
return [reshape(g_out, shape(x), ndim=x.ndim), None] return [reshape(g_out, shape(x), ndim=x.ndim), None]
def reshape(x, newshape, ndim=None): def reshape(x, newshape, ndim=None):
if not hasattr(reshape, 'op'):
reshape.op = {}
if ndim is None: if ndim is None:
ndim = get_vector_length(newshape) ndim = get_vector_length(newshape)
if ndim not in reshape.op: op = Reshape(ndim)
reshape.op[ndim] = Reshape(ndim) return op(x, newshape)
return reshape.op[ndim](x, newshape)
class Flatten(Op): class Flatten(Op):
...@@ -2233,13 +2239,14 @@ def grad(cost, wrt, g_cost=None, consider_constant=[]): ...@@ -2233,13 +2239,14 @@ def grad(cost, wrt, g_cost=None, consider_constant=[]):
NDArrayType(dtype = p.type.dtype, broadcastable = []), NDArrayType(dtype = p.type.dtype, broadcastable = []),
numpy.asarray(0, dtype=p.type.dtype)) numpy.asarray(0, dtype=p.type.dtype))
try: #try:
it = iter(wrt) #it = iter(wrt)
except: #except:
it = None #it = None
if it: #hasattr(wrt, '__iter__'): # isinstance(wrt, (list, tuple)): #if it: #hasattr(wrt, '__iter__'): # isinstance(wrt, (list, tuple)):
return [gmap.get(p, zero(p)) for p in it] if isinstance(wrt, (list, tuple)):
return [gmap.get(p, zero(p)) for p in wrt]
else: else:
return gmap.get(wrt, zero(wrt)) return gmap.get(wrt, zero(wrt))
...@@ -2348,8 +2355,6 @@ def verify_grad(testcase, op, pt, n_tests=1, rng=numpy.random, eps=1.0e-7, tol=0 ...@@ -2348,8 +2355,6 @@ def verify_grad(testcase, op, pt, n_tests=1, rng=numpy.random, eps=1.0e-7, tol=0
tensor_pt = [value(p.copy(), name='input %i'%i) for i,p in enumerate(pt)] tensor_pt = [value(p.copy(), name='input %i'%i) for i,p in enumerate(pt)]
#op can be either a function or an actual Op instance #op can be either a function or an actual Op instance
#print "OP", op
#print "TENSOR PT", tensor_pt
o_output = op(*tensor_pt) o_output = op(*tensor_pt)
if isinstance(o_output,list) > 1: if isinstance(o_output,list) > 1:
...@@ -2358,9 +2363,7 @@ def verify_grad(testcase, op, pt, n_tests=1, rng=numpy.random, eps=1.0e-7, tol=0 ...@@ -2358,9 +2363,7 @@ def verify_grad(testcase, op, pt, n_tests=1, rng=numpy.random, eps=1.0e-7, tol=0
# but this doesn't handle the case where not all the outputs are # but this doesn't handle the case where not all the outputs are
# differentiable... so I leave this as TODO for now -JB. # differentiable... so I leave this as TODO for now -JB.
o_fn = function(tensor_pt, o_output) o_fn = function(tensor_pt, o_output)
#print "PT B", pt
o_fn_out = o_fn(*[p.copy() for p in pt]) o_fn_out = o_fn(*[p.copy() for p in pt])
#print "PT C", pt
random_projection = rng.rand(*o_fn_out.shape) random_projection = rng.rand(*o_fn_out.shape)
t_r = as_ndarray_result(random_projection) t_r = as_ndarray_result(random_projection)
...@@ -2372,17 +2375,10 @@ def verify_grad(testcase, op, pt, n_tests=1, rng=numpy.random, eps=1.0e-7, tol=0 ...@@ -2372,17 +2375,10 @@ def verify_grad(testcase, op, pt, n_tests=1, rng=numpy.random, eps=1.0e-7, tol=0
symbolic_grad = grad(cost, tensor_pt,as_ndarray_result(1.0,name='g_cost')) symbolic_grad = grad(cost, tensor_pt,as_ndarray_result(1.0,name='g_cost'))
if 0:
print '----------'
for op in gof.graph.io_toposort(tensor_pt, symbolic_grad):
print op
grad_fn = function(tensor_pt, symbolic_grad) grad_fn = function(tensor_pt, symbolic_grad)
#print "PT D", pt
analytic_grad = grad_fn(*pt) analytic_grad = grad_fn(*pt)
#print "PT Z", pt
if not isinstance(analytic_grad, (list, tuple)): if not isinstance(analytic_grad, (list, tuple)):
analytic_grad = [analytic_grad] analytic_grad = [analytic_grad]
......
...@@ -897,7 +897,6 @@ class T_Join_and_Split(unittest.TestCase): ...@@ -897,7 +897,6 @@ class T_Join_and_Split(unittest.TestCase):
want = numpy.array([1, 2, 3]) want = numpy.array([1, 2, 3])
self.failUnless((eval_outputs([s]) == want).all()) self.failUnless((eval_outputs([s]) == want).all())
def test_join_vector(self): def test_join_vector(self):
a = as_ndarray_result(numpy.array([1, 2, 3])) a = as_ndarray_result(numpy.array([1, 2, 3]))
b = as_ndarray_result(numpy.array([7, 8, 9])) b = as_ndarray_result(numpy.array([7, 8, 9]))
...@@ -976,6 +975,16 @@ class T_Join_and_Split(unittest.TestCase): ...@@ -976,6 +975,16 @@ class T_Join_and_Split(unittest.TestCase):
verify_grad(self, lambda a, b: join(0,a,b), [v, 2*v]) verify_grad(self, lambda a, b: join(0,a,b), [v, 2*v])
verify_grad(self, lambda a, b: join(1,a,b), [v, 2*v]) verify_grad(self, lambda a, b: join(1,a,b), [v, 2*v])
def test_vector_len(self):
x = lscalar('x')
y = dscalar('y')
triple = as_ndarray_result((x, y, 9.0))
assert 3 == get_vector_length(triple)
a,b,c = triple
f = function([x,y], [b,c,a])
assert numpy.allclose(f(4, 5), [5, 9, 4])
class test_comparison(unittest.TestCase): class test_comparison(unittest.TestCase):
def test_gt(self): def test_gt(self):
......
Markdown 格式
0%
您正在添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论