提交 127c347b authored 作者: Frederic's avatar Frederic

Add test value during opt.

上级 77da4e17
......@@ -4600,8 +4600,12 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024):
elif ii in tmp_input:
tmp_s_input.append(tmp_scalar[tmp_input.index(ii)])
else:
tmp_s_input.append(scalar.Scalar(
ii.dtype).make_variable())
tmp = scalar.Scalar(ii.dtype).make_variable()
try:
tmp.tag.test_value = gof.op.get_test_value(ii).flatten()[0]
except AttributeError:
pass
tmp_s_input.append(tmp)
tmp_input.append(ii)
tmp_scalar.append(tmp_s_input[-1])
s_op = i.owner.op.scalar_op(*tmp_s_input)
......@@ -4651,6 +4655,11 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024):
s = s_inputs[inputs.index(i)]
else:
s = scalar.Scalar(i.dtype).make_variable()
try:
s.tag.test_value = gof.op.get_test_value(i).flatten()[0]
except AttributeError:
pass
inputs.append(i)
s_inputs.append(s)
s_g.append(s)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论