提交 01cda102 authored 作者: James Bergstra's avatar James Bergstra

fixed get_vector_length bug, added support for unpacking vector results

上级 95c4324b
......@@ -105,6 +105,12 @@ def as_ndarray_result(x, name = None, ndim=None):
return shape_padleft(x, n_ones=(ndim - x.type.ndim))
else:
return x
if isinstance(x, (tuple, list)) and any(isinstance(xi, Result) for xi in x):
try:
return stack(*x)
except (TypeError, ValueError):
pass
try:
return constant(x, name=name, ndim=ndim)
except TypeError:
......@@ -597,9 +603,13 @@ class _tensor_py_operators:
def copy(self): return tensor_copy(self)
def __iter__(self):
# This prevents accidental iteration via builtin.sum(self)
raise TypeError('NDArrayType does not support iteration. '
'Maybe you are using builtin.sum instead of theano.tensor.sum? (Maybe .max?)')
try:
for i in xrange(get_vector_length(self)):
yield self[i]
except:
# This prevents accidental iteration via builtin.sum(self)
raise TypeError('NDArrayType does not support iteration. '
'Maybe you are using builtin.sum instead of theano.tensor.sum? (Maybe .max?)')
# CONVENIENT ACCESS TO TYPE PROPERTIES
......@@ -1690,8 +1700,8 @@ class Join(Op):
if node.ndim != 1:
raise TypeError('argument must be symbolic vector')
inputs = node.owner.inputs
axis, tensors = inputs[0], inputs[1]
# if v is a vector, axis must be 0
axis, tensors = inputs[0], inputs[1:]
# if v is a vector, then axis must be 0
# the question is whether all the inputs are broadcastable.
if all(i.broadcastable[0] for i in tensors):
return len(tensors)
......@@ -1787,8 +1797,7 @@ def get_vector_length(v):
cases.
"""
if not isinstance(v, gof.Result):
v = constant(v)
v = as_ndarray_result(v)
if v.ndim != 1:
raise TypeError('argument must be symbolic vector')
if isinstance(v, gof.Constant) and v.type.ndim == 1:
......@@ -1923,13 +1932,10 @@ class Reshape(Op):
return [reshape(g_out, shape(x), ndim=x.ndim), None]
def reshape(x, newshape, ndim=None):
if not hasattr(reshape, 'op'):
reshape.op = {}
if ndim is None:
ndim = get_vector_length(newshape)
if ndim not in reshape.op:
reshape.op[ndim] = Reshape(ndim)
return reshape.op[ndim](x, newshape)
op = Reshape(ndim)
return op(x, newshape)
class Flatten(Op):
......@@ -2233,13 +2239,14 @@ def grad(cost, wrt, g_cost=None, consider_constant=[]):
NDArrayType(dtype = p.type.dtype, broadcastable = []),
numpy.asarray(0, dtype=p.type.dtype))
try:
it = iter(wrt)
except:
it = None
#try:
#it = iter(wrt)
#except:
#it = None
if it: #hasattr(wrt, '__iter__'): # isinstance(wrt, (list, tuple)):
return [gmap.get(p, zero(p)) for p in it]
#if it: #hasattr(wrt, '__iter__'): # isinstance(wrt, (list, tuple)):
if isinstance(wrt, (list, tuple)):
return [gmap.get(p, zero(p)) for p in wrt]
else:
return gmap.get(wrt, zero(wrt))
......@@ -2348,8 +2355,6 @@ def verify_grad(testcase, op, pt, n_tests=1, rng=numpy.random, eps=1.0e-7, tol=0
tensor_pt = [value(p.copy(), name='input %i'%i) for i,p in enumerate(pt)]
#op can be either a function or an actual Op instance
#print "OP", op
#print "TENSOR PT", tensor_pt
o_output = op(*tensor_pt)
if isinstance(o_output,list) > 1:
......@@ -2358,9 +2363,7 @@ def verify_grad(testcase, op, pt, n_tests=1, rng=numpy.random, eps=1.0e-7, tol=0
# but this doesn't handle the case where not all the outputs are
# differentiable... so I leave this as TODO for now -JB.
o_fn = function(tensor_pt, o_output)
#print "PT B", pt
o_fn_out = o_fn(*[p.copy() for p in pt])
#print "PT C", pt
random_projection = rng.rand(*o_fn_out.shape)
t_r = as_ndarray_result(random_projection)
......@@ -2372,17 +2375,10 @@ def verify_grad(testcase, op, pt, n_tests=1, rng=numpy.random, eps=1.0e-7, tol=0
symbolic_grad = grad(cost, tensor_pt,as_ndarray_result(1.0,name='g_cost'))
if 0:
print '----------'
for op in gof.graph.io_toposort(tensor_pt, symbolic_grad):
print op
grad_fn = function(tensor_pt, symbolic_grad)
#print "PT D", pt
analytic_grad = grad_fn(*pt)
#print "PT Z", pt
if not isinstance(analytic_grad, (list, tuple)):
analytic_grad = [analytic_grad]
......
......@@ -897,7 +897,6 @@ class T_Join_and_Split(unittest.TestCase):
want = numpy.array([1, 2, 3])
self.failUnless((eval_outputs([s]) == want).all())
def test_join_vector(self):
a = as_ndarray_result(numpy.array([1, 2, 3]))
b = as_ndarray_result(numpy.array([7, 8, 9]))
......@@ -976,6 +975,16 @@ class T_Join_and_Split(unittest.TestCase):
verify_grad(self, lambda a, b: join(0,a,b), [v, 2*v])
verify_grad(self, lambda a, b: join(1,a,b), [v, 2*v])
def test_vector_len(self):
x = lscalar('x')
y = dscalar('y')
triple = as_ndarray_result((x, y, 9.0))
assert 3 == get_vector_length(triple)
a,b,c = triple
f = function([x,y], [b,c,a])
assert numpy.allclose(f(4, 5), [5, 9, 4])
class test_comparison(unittest.TestCase):
def test_gt(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论