提交 b3084a31 authored 作者: Frederic's avatar Frederic

Faster opt local_fill_cut

上级 ddfb387c
...@@ -2559,12 +2559,12 @@ def local_fill_cut(node): ...@@ -2559,12 +2559,12 @@ def local_fill_cut(node):
# scalars, but we can't ignore the large matrix because it gives # scalars, but we can't ignore the large matrix because it gives
# the shape of the result. # the shape of the result.
if not opt.check_chain(node, T.Elemwise): if node.op != T.Elemwise:
return False return False
output = node.outputs[0] output = node.outputs[0]
try: try:
#reference is some input with the same type as the input but #reference is some input with the same type as the output but
#that is not produced by a fill #that is not produced by a fill
reference = [input reference = [input
for input in node.inputs for input in node.inputs
...@@ -2574,16 +2574,18 @@ def local_fill_cut(node): ...@@ -2574,16 +2574,18 @@ def local_fill_cut(node):
return False return False
new_inputs = [] new_inputs = []
new = False
for input in node.inputs: for input in node.inputs:
if opt.check_chain(input, T.fill): if input.owner and input.owner.op == T.fill:
model, filling = input.owner.inputs model, filling = input.owner.inputs
if encompasses_broadcastable(reference.type.broadcastable, if encompasses_broadcastable(reference.type.broadcastable,
filling.type.broadcastable): filling.type.broadcastable):
new_inputs.append(filling) new_inputs.append(filling)
new = True
continue continue
new_inputs.append(input) new_inputs.append(input)
if new_inputs == node.inputs: if not new:
return False return False
rval = node.op(*new_inputs) rval = node.op(*new_inputs)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论