提交 fea40244 authored 作者: Olivier Mastropietro's avatar Olivier Mastropietro

Completed the creation of the new returned Subtensor op

上级 6c0263e4
...@@ -189,6 +189,8 @@ def local_dimshuffle_subtensor(node): ...@@ -189,6 +189,8 @@ def local_dimshuffle_subtensor(node):
if 'x' in node.op.new_order : if 'x' in node.op.new_order :
return False return False
new_order = node.op.new_order new_order = node.op.new_order
# new order could be empty
if len(new_order) > 1:
past_dim = new_order[0] past_dim = new_order[0]
for dim in new_order[1:]: for dim in new_order[1:]:
if not dim > past_dim: if not dim > past_dim:
...@@ -209,14 +211,31 @@ def local_dimshuffle_subtensor(node): ...@@ -209,14 +211,31 @@ def local_dimshuffle_subtensor(node):
return False return False
# create a new idx_list for a new Subtensor object # create a new idx_list for a new Subtensor object
# idx_list could be longer than the len(missing_dims), that would happen with # have to loop on idx_list and inputs
# x[0, :, :].dimshuffle(1,0) # inputs has the length of sum of non None elements of idx_list (check in slice!)
# len(missing_dims) can be < len(idx_list), this happens if
# tensor was indexed such as x[scalar, :, :], check that as well
new_idx_list = list(input_.owner.op.idx_list) new_idx_list = list(input_.owner.op.idx_list)
offset = len(new_idx_list) - len(missing_dims) new_inputs = [input_.owner.inputs[0]]
zero = T.constant(0) zero = T.constant(0)
for dim in missing_dims: slice_attr_list = ['start','stop','step']
new_idx_list[dim + offset] = slice(zero, None, None) j = 0
import ipdb; ipdb.set_trace()
input_.owner.op.idx_list = tuple(new_idx_list) for idx in input_.owner.op.idx_list:
return [input_] if isinstance(idx, slice):
past_j = j
for slice_attr in slice_attr_list:
if getattr(idx, slice_attr) is not None:
new_inputs += [input_.owner.inputs[1+j]]
j += 1
if past_j == j:
# here is a slice(None, None, None), that's where
# we want to index with 0.
new_idx_list[j] = zero
new_inputs += [zero]
else:
new_inputs += [input_.owner.inputs[1+j]]
j += 1
import ipdb; ipdb.set_trace()
return [Subtensor(new_idx_list)(*new_inputs)]
return False return False
...@@ -178,10 +178,9 @@ def test_local_dimshuffle_subtensor(): ...@@ -178,10 +178,9 @@ def test_local_dimshuffle_subtensor():
x = tensor.patternbroadcast(x, (False, True, False, False)) x = tensor.patternbroadcast(x, (False, True, False, False))
i = tensor.iscalar('i') i = tensor.iscalar('i')
out = x[i, :, 10:30, :-1].dimshuffle(1,2) out = x[i, :, 10:30, ::-1].dimshuffle(1,2)
g = FunctionGraph([x,i], [out]) g = FunctionGraph([x,i], [out])
import ipdb; ipdb.set_trace()
dimshuffle_subtensor(g) dimshuffle_subtensor(g)
topo = g.toposort() topo = g.toposort()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论