提交 127c347b authored 作者: Frederic's avatar Frederic

Add test value during opt.

上级 77da4e17
...@@ -4600,8 +4600,12 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024): ...@@ -4600,8 +4600,12 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024):
elif ii in tmp_input: elif ii in tmp_input:
tmp_s_input.append(tmp_scalar[tmp_input.index(ii)]) tmp_s_input.append(tmp_scalar[tmp_input.index(ii)])
else: else:
tmp_s_input.append(scalar.Scalar( tmp = scalar.Scalar(ii.dtype).make_variable()
ii.dtype).make_variable()) try:
tmp.tag.test_value = gof.op.get_test_value(ii).flatten()[0]
except AttributeError:
pass
tmp_s_input.append(tmp)
tmp_input.append(ii) tmp_input.append(ii)
tmp_scalar.append(tmp_s_input[-1]) tmp_scalar.append(tmp_s_input[-1])
s_op = i.owner.op.scalar_op(*tmp_s_input) s_op = i.owner.op.scalar_op(*tmp_s_input)
...@@ -4651,6 +4655,11 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024): ...@@ -4651,6 +4655,11 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024):
s = s_inputs[inputs.index(i)] s = s_inputs[inputs.index(i)]
else: else:
s = scalar.Scalar(i.dtype).make_variable() s = scalar.Scalar(i.dtype).make_variable()
try:
s.tag.test_value = gof.op.get_test_value(i).flatten()[0]
except AttributeError:
pass
inputs.append(i) inputs.append(i)
s_inputs.append(s) s_inputs.append(s)
s_g.append(s) s_g.append(s)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论