提交 ca020339 authored 作者: Frederic's avatar Frederic

Make get_scalar_constant_value() infer in Subtensor(Join of vector).

This causes some advanced subtensor cases to infer False for broadcast instead of True. Fixes in part gh-2165
上级 419768e7
...@@ -632,22 +632,38 @@ def get_scalar_constant_value(orig_v, elemwise=True): ...@@ -632,22 +632,38 @@ def get_scalar_constant_value(orig_v, elemwise=True):
# test_sharedvar.py:test_shared_options.test_specify_shape_partial # test_sharedvar.py:test_shared_options.test_specify_shape_partial
if (v.owner.inputs[0].owner and if (v.owner.inputs[0].owner and
isinstance(v.owner.inputs[0].owner.op, Join) and isinstance(v.owner.inputs[0].owner.op, Join) and
len(v.owner.op.idx_list) == 1):
# Ensure the Join is joining only scalar variables (so that # Ensure the Join is joining only scalar variables (so that
# the constant value can be found at the same index as the one # the constant value can be found at the same index as the one
# used in the sub-tensor). # used in the sub-tensor).
python_all(var.ndim == 0 for var in if python_all(var.ndim == 0 for var in
v.owner.inputs[0].owner.inputs) and v.owner.inputs[0].owner.inputs[1:]):
len(v.owner.op.idx_list) == 1): idx = v.owner.op.idx_list[0]
if isinstance(idx, gof.Type):
idx = v.owner.op.idx_list[0] idx = get_scalar_constant_value(v.owner.inputs[1])
if isinstance(idx, gof.Type): # Note the '+ 1' is because the first argument to Join is the
idx = get_scalar_constant_value(v.owner.inputs[1]) # axis.
# Note the '+ 1' is because the first argument to Join is the ret = v.owner.inputs[0].owner.inputs[idx + 1]
# axis. ret = get_scalar_constant_value(ret)
ret = v.owner.inputs[0].owner.inputs[idx + 1] # join can cast implicitly its input in some case.
ret = get_scalar_constant_value(ret) return theano._asarray(ret, dtype=v.type.dtype)
# join can cast implicitly its input in some case. if python_all(var.ndim == 1 for var in
return theano._asarray(ret, dtype=v.type.dtype) v.owner.inputs[0].owner.inputs[1:]):
idx = v.owner.op.idx_list[0]
if isinstance(idx, gof.Type):
idx = get_scalar_constant_value(v.owner.inputs[1])
try:
#TODO: assert joined axis is 0.
length = 0
for joined in v.owner.inputs[0].owner.inputs[1:]:
ll = get_vector_length(joined)
if idx < length + ll:
return get_scalar_constant_value(joined[idx-length])
length += ll
except TypeError:
pass
except ValueError:
pass
elif (v.owner.inputs[0].owner and elif (v.owner.inputs[0].owner and
isinstance(v.owner.inputs[0].owner.op, isinstance(v.owner.inputs[0].owner.op,
...@@ -3663,6 +3679,15 @@ def get_vector_length(v): ...@@ -3663,6 +3679,15 @@ def get_vector_length(v):
return len(v.owner.inputs) return len(v.owner.inputs)
if v.owner and isinstance(v.owner.op, Shape): if v.owner and isinstance(v.owner.op, Shape):
return v.owner.inputs[0].type.ndim return v.owner.inputs[0].type.ndim
# If we take this slice: var[:0], we know it will have 0 elements.
if (v.owner and
isinstance(v.owner.op, theano.tensor.subtensor.Subtensor) and
isinstance(v.owner.op.idx_list[0], slice) and
v.owner.op.idx_list[0].start in [None, 0]):
stop = theano.tensor.subtensor.get_idx_list(
v.owner.inputs, v.owner.op.idx_list)[0].stop
if extract_constant(stop) == 0:
return 0
raise ValueError("length not known") raise ValueError("length not known")
......
...@@ -2095,7 +2095,12 @@ def take(a, indices, axis=None, mode='raise'): ...@@ -2095,7 +2095,12 @@ def take(a, indices, axis=None, mode='raise'):
shape = indices.shape shape = indices.shape
ndim = indices.ndim ndim = indices.ndim
else: else:
shape = theano.tensor.concatenate( # If axis is 0, don't generate a useless concatenation.
[a.shape[:axis], indices.shape, a.shape[axis + 1:]]) if axis == 0:
shape = theano.tensor.concatenate(
[indices.shape, a.shape[axis + 1:]])
else:
shape = theano.tensor.concatenate(
[a.shape[:axis], indices.shape, a.shape[axis + 1:]])
ndim = a.ndim + indices.ndim - 1 ndim = a.ndim + indices.ndim - 1
return take(a, indices.flatten(), axis, mode).reshape(shape, ndim) return take(a, indices.flatten(), axis, mode).reshape(shape, ndim)
...@@ -27,7 +27,7 @@ from theano.tensor.subtensor import (inc_subtensor, set_subtensor, ...@@ -27,7 +27,7 @@ from theano.tensor.subtensor import (inc_subtensor, set_subtensor,
from theano.tensor import (as_tensor_variable, _shared, from theano.tensor import (as_tensor_variable, _shared,
NotScalarConstantError, NotScalarConstantError,
fscalar, iscalar, dscalar, cscalar, fscalar, iscalar, dscalar, cscalar,
vector, dvector, fvector, lvector, vector, dvector, fvector, lvector, lrow,
fmatrix, dmatrix, lmatrix, matrix, fmatrix, dmatrix, lmatrix, matrix,
ctensor3, dtensor4) ctensor3, dtensor4)
from theano.tensor.tests.test_basic import rand, randint_ranged, inplace_func from theano.tensor.tests.test_basic import rand, randint_ranged, inplace_func
...@@ -1140,6 +1140,7 @@ class TestAdvancedSubtensor(unittest.TestCase): ...@@ -1140,6 +1140,7 @@ class TestAdvancedSubtensor(unittest.TestCase):
self.ix1 = lvector() # advanced 1d query self.ix1 = lvector() # advanced 1d query
self.ix12 = lvector() self.ix12 = lvector()
self.ix2 = lmatrix() self.ix2 = lmatrix()
self.ixr = lrow()
def eval_output_and_check(self, t): def eval_output_and_check(self, t):
f = inplace_func([], t, mode=self.mode) f = inplace_func([], t, mode=self.mode)
...@@ -1164,6 +1165,11 @@ class TestAdvancedSubtensor(unittest.TestCase): ...@@ -1164,6 +1165,11 @@ class TestAdvancedSubtensor(unittest.TestCase):
assert a.broadcastable == self.ix2.broadcastable, ( assert a.broadcastable == self.ix2.broadcastable, (
a.broadcastable, self.ix2.broadcastable) a.broadcastable, self.ix2.broadcastable)
def test_index_into_mat_w_row(self):
a = self.m[self.ixr]
assert a.dtype == self.m.dtype, (a.dtype, self.m.dtype)
assert a.broadcastable == (True, False, False)
def test_index_w_int_and_vec(self): def test_index_w_int_and_vec(self):
# like test_ok_list, but with a single index on the first one # like test_ok_list, but with a single index on the first one
# data has to have at least 2 dimensions # data has to have at least 2 dimensions
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论