提交 64efe269 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Updated the subtensor-subtensor merge optimization to make use of the previously introduced function for merging two slices.
上级 7a45b553
...@@ -1280,84 +1280,39 @@ def local_subtensor_merge(node): ...@@ -1280,84 +1280,39 @@ def local_subtensor_merge(node):
# x actual tensor on which we are picking slices # x actual tensor on which we are picking slices
x = u.owner.inputs[0] x = u.owner.inputs[0]
# slices of the first applied subtensor # slices of the first applied subtensor
sl1 = T.get_idx_list(u.owner.inputs, u.owner.op.idx_list) slices1 = T.get_idx_list(u.owner.inputs, u.owner.op.idx_list)
sl2 = T.get_idx_list(node.inputs , node.op.idx_list ) slices2 = T.get_idx_list(node.inputs , node.op.idx_list )
# Get the shapes of the vectors ! # Get the shapes of the vectors !
try: try:
# try not to introduce new shape into the graph # try not to introduce new shape into the graph
xshape = node.env.shape_feature.shape_of[x] xshape = node.env.shape_feature.shape_of[x]
ushape = node.env.shape_feature.shape_of[u] ushape = node.env.shape_feature.shape_of[u]
except: except AttributeError:
# Following the suggested use of shape_feature which should
# consider the case when the compilation mode doesn't
# include the ShapeFeature
xhsape = x.shape xhsape = x.shape
ushape = u.shape ushape = u.shape
# convert each list of slices into canonical forms merged_slices = []
cnf1 = [ T.get_canonical_form_slice(x,xshape[i]) for (i,x) in pos_2 = 0
enumerate(sl1) ] for pos_1, slice1 in enumerate(slices1):
cnf2 = [ T.get_canonical_form_slice(x,ushape[i]) for (i,x) in if type(slice1) is slice:
enumerate(sl2) ] merged_slices.append(
merge_two_slices(slice1,
# Some helpful utility functions : slices2[pos_2],
def safe_prod(x,y): xshape[pos_1],
if x is None: ushape[pos_2]))
return y pos_2 += 1
if y is None:
return x
return x*y
merged_cnf = []
pos_cnf2 = 0
for idx,(sl, reverse) in enumerate(cnf1):
if type(sl) is not slice:
merged_cnf += [ (sl, reverse) ]
elif type(cnf2[pos_cnf2][0]) is not slice:
xlen = xshape[idx]
ulen = ushape[idx]
udx = cnf2[pos_cnf2][0]
if reverse is None:
# we need to check if things are fine
val = sl.start + udx
val = T.switch(T.lt(udx,0), xlen+1, val)
val = T.switch(T.ge(udx,ulen), xlen+1, val)
merged_cnf += [ (val,None) ]
pos_cnf2 += 1
else:
p_val = sl.start + cnf2[pos_cnf2][0]
n_val = sl.stop - sl.start - 1 - cnf2[pos_cnf2][0]
val = T.switch(T.lt(reverse,0), n_val, p_val)
val = T.switch(T.lt(udx,0), xlen+1, val)
val = T.switch(T.ge(udx,ulen), xlen+1, val)
merged_cnf += [(val, None)]
pos_cnf2 += 1
else: else:
start = sl.start + cnf2[pos_cnf2][0].start merged_slices.append(slice1)
stop = sl.start + cnf2[pos_cnf2][0].stop
step = sl.step * cnf2[pos_cnf2][0].step merged_slices += slices2[pos_2:]
merged_reverse = safe_prod(reverse, cnf2[pos_cnf2][1]) subtens = T.Subtensor(merged_slices)
pos_cnf2 += 1 sl_ins = T.Subtensor.collapse(
merged_cnf += [(slice(start, stop, step), merged_slices,
merged_reverse)] lambda x: isinstance(x, T.Variable))
return [ subtens.make_node(node.inputs[0], *sl_ins).outputs[0]]
merged_cnf += cnf2[pos_cnf2:]
result_slices = []
# We need to apply the reverse flag where needed
__pos = 0
for cnf, reverse in merged_cnf:
__pos +=1
if reverse is not None:
start = T.switch(T.lt(reverse,0), cnf.stop-1, cnf.start)
stop = T.switch(T.lt(reverse,0), cnf.start-1, cnf.stop)
result_slices += [slice(start,stop,cnf.step*reverse)]
else:
result_slices += [cnf]
subtens = T.Subtensor(result_slices)
sl_ins = T.Subtensor.collapse(
result_slices
, lambda x: isinstance(x, T.Variable))
out = subtens.make_node(node.inputs[0], *sl_ins).outputs[0]
return [subtens.make_node(u.owner.inputs[0], *sl_ins).outputs[0]]
@register_canonicalize @register_canonicalize
@gof.local_optimizer([None]) @gof.local_optimizer([None])
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论