提交 5259e427 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #5456 from affanv14/indexwork

Implement indexing with empty list
...@@ -26,6 +26,19 @@ def test_numpy_method(): ...@@ -26,6 +26,19 @@ def test_numpy_method():
np.nan_to_num(fct(data))) np.nan_to_num(fct(data)))
def test_empty_list_indexing():
ynp = np.zeros((2, 2))[:, []]
znp = np.zeros((2, 2))[:, ()]
data = [[0, 0], [0, 0]]
x = tt.dmatrix('x')
y = x[:, []]
z = x[:, ()]
fy = theano.function([x], y)
fz = theano.function([x], z)
assert_equal(fy(data).shape, ynp.shape)
assert_equal(fz(data).shape, znp.shape)
def test_copy(): def test_copy():
x = tt.dmatrix('x') x = tt.dmatrix('x')
data = np.random.rand(5, 5) data = np.random.rand(5, 5)
......
...@@ -501,6 +501,10 @@ class _tensor_py_operators(object): ...@@ -501,6 +501,10 @@ class _tensor_py_operators(object):
args[ellipsis_at: ellipsis_at + 1] = ( args[ellipsis_at: ellipsis_at + 1] = (
[slice(None)] * (self.ndim - (len(args) - 1 - new_axes))) [slice(None)] * (self.ndim - (len(args) - 1 - new_axes)))
# Force input to be int64 datatype if input is an empty list or tuple
# Else leave it as is if it is a real number
args = tuple([numpy.array(inp, dtype=numpy.int64)
if(inp == [] or inp == ()) else inp for inp in args])
# Convert python literals to theano constants # Convert python literals to theano constants
args = theano.tensor.subtensor.make_constant(args) args = theano.tensor.subtensor.make_constant(args)
# Determine if advanced indexing is needed or not # Determine if advanced indexing is needed or not
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论