提交 d3bee2da authored 作者: Amjad Almahairi's avatar Amjad Almahairi

fixed a bug

上级 547255d5
...@@ -1605,34 +1605,38 @@ def local_useless_elemwise(node): ...@@ -1605,34 +1605,38 @@ def local_useless_elemwise(node):
if isinstance(node.inputs[0], T.TensorConstant): if isinstance(node.inputs[0], T.TensorConstant):
const_val = T.extract_constant(node.inputs[0]) const_val = T.extract_constant(node.inputs[0])
if const_val == 0: if not isinstance(const_val, Variable):
return zeros_like(node, 1) if const_val == 0:
else: return zeros_like(node, 1)
return [node.inputs[1]] else:
return [node.inputs[1]]
if isinstance(node.inputs[1], T.TensorConstant): if isinstance(node.inputs[1], T.TensorConstant):
const_val = T.extract_constant(node.inputs[1]) const_val = T.extract_constant(node.inputs[1])
if const_val == 0: if not isinstance(const_val, Variable):
return zeros_like(node, 0) if const_val == 0:
else: return zeros_like(node, 0)
return [node.inputs[0]] else:
return [node.inputs[0]]
elif (isinstance(node.op.scalar_op, scalar.OR) and elif (isinstance(node.op.scalar_op, scalar.OR) and
len(node.inputs) == 2): len(node.inputs) == 2):
if isinstance(node.inputs[0], T.TensorConstant): if isinstance(node.inputs[0], T.TensorConstant):
const_val = T.extract_constant(node.inputs[0]) const_val = T.extract_constant(node.inputs[0])
if const_val == 0: if not isinstance(const_val, Variable):
return [node.inputs[1]] if const_val == 0:
else: return [node.inputs[1]]
return ones_like(node, 1) else:
return ones_like(node, 1)
if isinstance(node.inputs[1], T.TensorConstant): if isinstance(node.inputs[1], T.TensorConstant):
const_val = T.extract_constant(node.inputs[1]) const_val = T.extract_constant(node.inputs[1])
if const_val == 0: if not isinstance(const_val, Variable):
return [node.inputs[0]] if const_val == 0:
else: return [node.inputs[0]]
return ones_like(node, 0) else:
return ones_like(node, 0)
elif (isinstance(node.op.scalar_op, scalar.XOR) and elif (isinstance(node.op.scalar_op, scalar.XOR) and
len(node.inputs) == 2): len(node.inputs) == 2):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论