提交 ad810a3f authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Fixed the merge optimization for subtensors to cover all possible

combinations of two subtensors.
上级 691bfddf
...@@ -1158,140 +1158,96 @@ def local_subtensor_lift(node): ...@@ -1158,140 +1158,96 @@ def local_subtensor_lift(node):
@gof.local_optimizer([]) @gof.local_optimizer([])
def local_subtensor_merge(node): def local_subtensor_merge(node):
""" """
1) var[int:][-1] -> var[-1] # a little different for when the first subtensor is empty. Refractored optimization to deal with all cases of tensor merging.
2) var[::-1][int] -> var[-int-1] Given a subgraph of the form Subtensor(Subtensor(u)), the optimization
3) var[::-1][:int] -> var[:-int-1:-1] expresses all slices in a canonical form, and then merges them together.
4) var[int1::][:int2] -> """
var[int1:int2 + switch(int2<0,
0,
switch(int1>=0,
int1,
maximum(u.owner.inputs[0].shape[0]+int1,
0))]
"""
if (isinstance(node.op, T.Subtensor) and
len(node.op.idx_list)==1):
if isinstance(node.op, T.Subtensor):
u = node.inputs[0] u = node.inputs[0]
if (not u.owner or len(u.clients) > 1 or if u.owner and isinstance( u.owner.op, T.Subtensor):
not isinstance(u.owner.op, T.Subtensor)): # We can merge :)
return False # x actual tensor on which we are picking slices
x = u.owner.inputs[0]
# var[int:][-1] -> var[-1] # slices of the first applied subtensor
if (len(node.inputs)==1 and sl1 = T.get_idx_list(u.owner.inputs, u.owner.op.idx_list)
node.op.idx_list[0]==-1 and sl2 = T.get_idx_list(node.inputs , node.op.idx_list )
len(u.owner.op.idx_list)==1 and # Get the shapes of the vectors !
isinstance(u.owner.op.idx_list[0], slice) and try:
u.owner.op.idx_list[0].stop is None and # try not to introduce new shape into the graph
u.owner.op.idx_list[0].step is None xshape = node.env.shape_feature.shape_of[x]
): ushape = node.env.shape_feature.shape_of[u]
u_start = u.owner.op.idx_list[0].start except:
xhsape = x.shape
if len(u.owner.inputs) == 1 and isinstance(u_start, int): ushape = u.shape
start0 = T.as_tensor_variable(u_start)
elif (len(u.owner.inputs) == 2 and # convert each list of slices into canonical forms
isinstance (u_start, scalar.basic.Scalar)): cnf1 = [ T.get_canonical_form_slice(x,xshape[i]) for (i,x) in
start0 = T.tensor_from_scalar(u.owner.inputs[1]) enumerate(sl1) ]
else: cnf2 = [ T.get_canonical_form_slice(x,ushape[i]) for (i,x) in
return False enumerate(sl2) ]
len0 = u.owner.inputs[0].shape[0] # Some helpful utility functions :
# The following is equivalent to: def safe_prod(x,y):
# if start0 <= -u.shape[0]: if x is None:
# actual_start0 = 0 return y
# elif start0 < 0: if y is None:
# actual_start0 = start0 + u.shape[0] return x
# else: return x*y
# actual_start0 = start0
actual_start0 = (start0 > -len0) * (start0 + ((start0 < 0) * len0)) merged_cnf = []
pos_cnf2 = 0
# if actual_start < u.shape[0]: for idx,(sl, reverse) in enumerate(cnf1):
# new_index = -1 if type(sl) is not slice:
# else: # Will give an IndexError merged_cnf += [ (sl, reverse) ]
# new_index = actual_start elif type(cnf2[pos_cnf2][0]) is not slice:
new_index = -1 + (actual_start0 >= len0) * (actual_start0 + 1) xlen = xshape[idx]
ulen = ushape[idx]
new_index = T.scalar_from_tensor(new_index) udx = cnf2[pos_cnf2][0]
return [u.owner.inputs[0][new_index]] if reverse is None:
# we need to check if things are fine
# var[::-1][int] -> var[-int-1] val = sl.start + udx
if (len(node.inputs) in [1,2] and val = T.switch(T.lt(udx,0), xlen+1, val)
isinstance(node.op.idx_list[0], (int, scalar.basic.Scalar)) and val = T.switch(T.ge(udx,ulen), xlen+1, val)
len(u.owner.op.idx_list)==1 and merged_cnf += [ (val,None) ]
isinstance(u.owner.op.idx_list[0], slice) and pos_cnf2 += 1
u.owner.op.idx_list[0].start is None and else:
u.owner.op.idx_list[0].stop is None and p_val = sl.start + cnf2[pos_cnf2][0]
u.owner.op.idx_list[0].step == -1 n_val = sl.stop - sl.start - 1 - cnf2[pos_cnf2][0]
): val = T.switch(T.lt(reverse,0), n_val, p_val)
idx = node.op.idx_list[0] val = T.switch(T.lt(udx,0), xlen+1, val)
if len(node.inputs) == 1 and isinstance(idx, int): val = T.switch(T.ge(udx,ulen), xlen+1, val)
idx = T.as_tensor_variable(idx) merged_cnf += [(val, None)]
elif (len(node.inputs) == 2 and pos_cnf2 += 1
isinstance (idx, scalar.basic.Scalar)): else:
idx = T.tensor_from_scalar(node.inputs[1]) start = sl.start + cnf2[pos_cnf2][0].start
else: stop = sl.start + cnf2[pos_cnf2][0].stop
return False step = sl.step * cnf2[pos_cnf2][0].step
merged_reverse = safe_prod(reverse, cnf2[pos_cnf2][1])
return [u.owner.inputs[0][-idx-1]] pos_cnf2 += 1
merged_cnf += [(slice(start, stop, step),
# var[::-1][:int] -> var[:-int-1:-1] merged_reverse)]
if (len(node.inputs) in [1,2] and
len(u.owner.op.idx_list)==1 and merged_cnf += cnf2[pos_cnf2:]
isinstance(node.op.idx_list[0], slice) and result_slices = []
node.op.idx_list[0].start in [0, None] and # We need to apply the reverse flag where needed
isinstance(node.op.idx_list[0].stop, (int, scalar.basic.Scalar)) and __pos = 0
node.op.idx_list[0].step is None and for cnf, reverse in merged_cnf:
isinstance(u.owner.op.idx_list[0], slice) and __pos +=1
u.owner.op.idx_list[0].start is None and if reverse is not None:
u.owner.op.idx_list[0].stop is None and start = T.switch(T.lt(reverse,0), cnf.stop-1, cnf.start)
u.owner.op.idx_list[0].step == -1 stop = T.switch(T.lt(reverse,0), cnf.start-1, cnf.stop)
): result_slices += [slice(start,stop,cnf.step*reverse)]
slice_idx = node.op.idx_list[0] else:
idx = slice_idx.stop result_slices += [cnf]
if len(node.inputs) == 1 and isinstance(idx, int):
idx = T.as_tensor_variable(idx) subtens = T.Subtensor(result_slices)
elif (len(node.inputs) == 2 and sl_ins = T.Subtensor.collapse(
isinstance (idx, scalar.basic.Scalar)): result_slices
idx = T.tensor_from_scalar(node.inputs[1]) , lambda x: isinstance(x, T.Variable))
else: out = subtens.make_node(node.inputs[0], *sl_ins).outputs[0]
return False return [subtens.make_node(u.owner.inputs[0], *sl_ins).outputs[0]]
return [u.owner.inputs[0][:-idx-1:-1]]
# var[int1::][:int2]
if (len(node.inputs) in [1, 2] and
isinstance(node.op.idx_list[0], slice) and
node.op.idx_list[0].start in [0, None] and
isinstance(node.op.idx_list[0].stop,(int, scalar.basic.Scalar)) and
node.op.idx_list[0].step is None and
len(u.owner.op.idx_list)==1 and
isinstance(u.owner.op.idx_list[0], slice) and
isinstance(u.owner.op.idx_list[0].start,(int, scalar.basic.Scalar)) and
u.owner.op.idx_list[0].stop in [sys.maxint, None] and
u.owner.op.idx_list[0].step is None
):
idx1 = u.owner.op.idx_list[0].start
idx2 = node.op.idx_list[0].stop
if isinstance(idx1, scalar.basic.Scalar):
idx1 = T.tensor_from_scalar(u.owner.inputs[1])
elif isinstance(idx1, int):
idx1 = T.as_tensor_variable(idx1)
if isinstance(idx2, scalar.basic.Scalar):
idx2 = T.tensor_from_scalar(node.inputs[1])
elif isinstance(idx2, int):
idx2 = T.as_tensor_variable(idx2)
# Get positive version of idx1
# TODO: use Razvan's code for that
# The maximum is needed so that shape[0] + idx1 >= 0
neg_idx1 = T.maximum(u.owner.inputs[0].shape[0]+idx1, 0)
new_idx1 = T.switch((idx1 >= 0), idx1, neg_idx1)
# If idx2<0, we are indexing from the end, so idx2 is OK
# If we are indexing from the beginning, we need to add pos_idx1
new_idx2 = idx2 + T.switch((idx2 < 0), 0, new_idx1)
return [u.owner.inputs[0][new_idx1:new_idx2]]
@register_canonicalize @register_canonicalize
@gof.local_optimizer([None]) @gof.local_optimizer([None])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论