提交 6f2abbeb authored 作者: Frederic's avatar Frederic

Good implementation of shape_i() optimized function

上级 565fd3eb
......@@ -406,11 +406,21 @@ def shape_i(var, i, fgraph=None):
if fgraph is None and hasattr(var, 'fgraph'):
fgraph = var.fgraph
if fgraph and hasattr(fgraph, 'shape_feature'):
if var not in fgraph.shape_feature.shape_of:
# If var isn't in the ShapeFeature, add it.
fgraph.shape_feature.on_import(fgraph, var.owner,
'gof.ops.shape_i')
return fgraph.shape_feature.shape_of[var][i]
shape_feature = fgraph.shape_feature
shape_of = shape_feature.shape_of
def recur(node):
if not hasattr(node.outputs[0], 'fgraph'):
for inp in node.inputs:
if inp.owner:
recur(inp.owner)
# If the output var isn't marked as being in the graph,
# we need to att it in the ShapeFeature.
shape_feature.on_import(fgraph, node,
'gof.ops.shape_i')
if var not in shape_of:
recur(var.owner)
return shape_of[var][i]
# If we are not able to use the shape feature, we should not put
# Shape_i in the graph. Otherwise, the shape feature optimization
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论