提交 d3bee2da authored 作者: Amjad Almahairi's avatar Amjad Almahairi

fixed a bug

上级 547255d5
...@@ -1605,6 +1605,7 @@ def local_useless_elemwise(node): ...@@ -1605,6 +1605,7 @@ def local_useless_elemwise(node):
if isinstance(node.inputs[0], T.TensorConstant): if isinstance(node.inputs[0], T.TensorConstant):
const_val = T.extract_constant(node.inputs[0]) const_val = T.extract_constant(node.inputs[0])
if not isinstance(const_val, Variable):
if const_val == 0: if const_val == 0:
return zeros_like(node, 1) return zeros_like(node, 1)
else: else:
...@@ -1612,6 +1613,7 @@ def local_useless_elemwise(node): ...@@ -1612,6 +1613,7 @@ def local_useless_elemwise(node):
if isinstance(node.inputs[1], T.TensorConstant): if isinstance(node.inputs[1], T.TensorConstant):
const_val = T.extract_constant(node.inputs[1]) const_val = T.extract_constant(node.inputs[1])
if not isinstance(const_val, Variable):
if const_val == 0: if const_val == 0:
return zeros_like(node, 0) return zeros_like(node, 0)
else: else:
...@@ -1622,6 +1624,7 @@ def local_useless_elemwise(node): ...@@ -1622,6 +1624,7 @@ def local_useless_elemwise(node):
if isinstance(node.inputs[0], T.TensorConstant): if isinstance(node.inputs[0], T.TensorConstant):
const_val = T.extract_constant(node.inputs[0]) const_val = T.extract_constant(node.inputs[0])
if not isinstance(const_val, Variable):
if const_val == 0: if const_val == 0:
return [node.inputs[1]] return [node.inputs[1]]
else: else:
...@@ -1629,6 +1632,7 @@ def local_useless_elemwise(node): ...@@ -1629,6 +1632,7 @@ def local_useless_elemwise(node):
if isinstance(node.inputs[1], T.TensorConstant): if isinstance(node.inputs[1], T.TensorConstant):
const_val = T.extract_constant(node.inputs[1]) const_val = T.extract_constant(node.inputs[1])
if not isinstance(const_val, Variable):
if const_val == 0: if const_val == 0:
return [node.inputs[0]] return [node.inputs[0]]
else: else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论