提交 b3084a31 authored 作者: Frederic's avatar Frederic

Faster opt local_fill_cut

上级 ddfb387c
......@@ -2559,12 +2559,12 @@ def local_fill_cut(node):
# scalars, but we can't ignore the large matrix because it gives
# the shape of the result.
if not opt.check_chain(node, T.Elemwise):
if node.op != T.Elemwise:
return False
output = node.outputs[0]
try:
#reference is some input with the same type as the input but
#reference is some input with the same type as the output but
#that is not produced by a fill
reference = [input
for input in node.inputs
......@@ -2574,16 +2574,18 @@ def local_fill_cut(node):
return False
new_inputs = []
new = False
for input in node.inputs:
if opt.check_chain(input, T.fill):
if input.owner and input.owner.op == T.fill:
model, filling = input.owner.inputs
if encompasses_broadcastable(reference.type.broadcastable,
filling.type.broadcastable):
new_inputs.append(filling)
new = True
continue
new_inputs.append(input)
if new_inputs == node.inputs:
if not new:
return False
rval = node.op(*new_inputs)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论