提交 fea40244 authored 作者: Olivier Mastropietro's avatar Olivier Mastropietro

Completed the creation of the new returned Subtensor op

上级 6c0263e4
......@@ -189,12 +189,14 @@ def local_dimshuffle_subtensor(node):
if 'x' in node.op.new_order :
return False
new_order = node.op.new_order
past_dim = new_order[0]
for dim in new_order[1:]:
if not dim > past_dim:
return False
else:
past_dim = dim
# new order could be empty
if len(new_order) > 1:
past_dim = new_order[0]
for dim in new_order[1:]:
if not dim > past_dim:
return False
else:
past_dim = dim
input_ = node.inputs[0]
if isinstance(input_.owner.op, Subtensor):
......@@ -209,14 +211,31 @@ def local_dimshuffle_subtensor(node):
return False
# create a new idx_list for a new Subtensor object
# idx_list could be longer than the len(missing_dims), that would happen with
# x[0, :, :].dimshuffle(1,0)
# have to loop on idx_list and inputs
# inputs has the length of sum of non None elements of idx_list (check in slice!)
# len(missing_dims) can be < len(idx_list), this happens if
# tensor was indexed such as x[scalar, :, :], check that as well
new_idx_list = list(input_.owner.op.idx_list)
offset = len(new_idx_list) - len(missing_dims)
new_inputs = [input_.owner.inputs[0]]
zero = T.constant(0)
for dim in missing_dims:
new_idx_list[dim + offset] = slice(zero, None, None)
input_.owner.op.idx_list = tuple(new_idx_list)
return [input_]
slice_attr_list = ['start','stop','step']
j = 0
import ipdb; ipdb.set_trace()
for idx in input_.owner.op.idx_list:
if isinstance(idx, slice):
past_j = j
for slice_attr in slice_attr_list:
if getattr(idx, slice_attr) is not None:
new_inputs += [input_.owner.inputs[1+j]]
j += 1
if past_j == j:
# here is a slice(None, None, None), that's where
# we want to index with 0.
new_idx_list[j] = zero
new_inputs += [zero]
else:
new_inputs += [input_.owner.inputs[1+j]]
j += 1
import ipdb; ipdb.set_trace()
return [Subtensor(new_idx_list)(*new_inputs)]
return False
......@@ -178,10 +178,9 @@ def test_local_dimshuffle_subtensor():
x = tensor.patternbroadcast(x, (False, True, False, False))
i = tensor.iscalar('i')
out = x[i, :, 10:30, :-1].dimshuffle(1,2)
out = x[i, :, 10:30, ::-1].dimshuffle(1,2)
g = FunctionGraph([x,i], [out])
import ipdb; ipdb.set_trace()
dimshuffle_subtensor(g)
topo = g.toposort()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论