提交 e0ea29ae authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Use TensorType.dtype instead of TensorVariable.dtype in local_elemwise_fusion_op

上级 95f0c13d
......@@ -3112,7 +3112,7 @@ def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None
elif ii in tmp_input:
tmp_s_input.append(tmp_scalar[tmp_input.index(ii)])
else:
tmp = aes.get_scalar_type(ii.dtype).make_variable()
tmp = aes.get_scalar_type(ii.type.dtype).make_variable()
try:
tv = get_test_value(ii)
if tv.size > 0:
......@@ -3180,7 +3180,7 @@ def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None
if inputs.count(i) == node.inputs.count(i):
s = s_inputs[inputs.index(i)]
else:
s = aes.get_scalar_type(i.dtype).make_variable()
s = aes.get_scalar_type(i.type.dtype).make_variable()
try:
if config.compute_test_value != "off":
v = get_test_value(i)
......@@ -3232,7 +3232,7 @@ your code will run correctly, but may be slower."""
new_node = maker(node, composite_op)(*inputs).owner
assert len(new_node.outputs) == 1
assert node.outputs[0].dtype == new_node.outputs[0].dtype
assert node.outputs[0].type.dtype == new_node.outputs[0].type.dtype
if len(new_node.inputs) > max_nb_input:
_logger.warning(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论