提交 1ddccd76 authored 作者: Amjad Almahairi's avatar Amjad Almahairi

some formatting

上级 ddef757d
...@@ -1592,76 +1592,78 @@ def local_useless_elemwise(node): ...@@ -1592,76 +1592,78 @@ def local_useless_elemwise(node):
elif node.op.scalar_op == theano.scalar.mul and len(node.inputs) == 1: elif node.op.scalar_op == theano.scalar.mul and len(node.inputs) == 1:
# No need to copy over any stack trace # No need to copy over any stack trace
return [node.inputs[0]] return [node.inputs[0]]
elif node.op.scalar_op == theano.scalar.add and len(node.inputs) == 1: elif node.op.scalar_op == theano.scalar.add and len(node.inputs) == 1:
# No need to copy over any stack trace # No need to copy over any stack trace
return [node.inputs[0]] return [node.inputs[0]]
elif ( elif (
node.op.scalar_op == theano.scalar.identity node.op.scalar_op == theano.scalar.identity and
and len(node.inputs) == 1 len(node.inputs) == 1
): ):
return [node.inputs[0]] return [node.inputs[0]]
elif ( elif (
isinstance(node.op.scalar_op, scalar.AND) isinstance(node.op.scalar_op, scalar.AND) and
and len(node.inputs) == 2 len(node.inputs) == 2
): ):
if ( if (
isinstance(node.inputs[0], T.TensorConstant) isinstance(node.inputs[0], T.TensorConstant) and
and T.extract_constant(node.inputs[0]) == 1 T.extract_constant(node.inputs[0]) == 1
): ):
return [node.inputs[1]] return [node.inputs[1]]
if (
isinstance(node.inputs[1], T.TensorConstant) if (
and T.extract_constant(node.inputs[1]) == 1 isinstance(node.inputs[1], T.TensorConstant) and
T.extract_constant(node.inputs[1]) == 1
): ):
return [node.inputs[0]] return [node.inputs[0]]
if ( if (
isinstance(node.inputs[0], T.TensorConstant) isinstance(node.inputs[0], T.TensorConstant) and
and T.extract_constant(node.inputs[0]) == 0 T.extract_constant(node.inputs[0]) == 0
): ):
return zeros_like(node, 1) return zeros_like(node, 1)
if ( if (
isinstance(node.inputs[1], T.TensorConstant) isinstance(node.inputs[1], T.TensorConstant) and
and T.extract_constant(node.inputs[1]) == 0 T.extract_constant(node.inputs[1]) == 0
): ):
return zeros_like(node, 0) return zeros_like(node, 0)
elif ( elif (
isinstance(node.op.scalar_op, scalar.OR) isinstance(node.op.scalar_op, scalar.OR) and
and len(node.inputs) == 2 len(node.inputs) == 2
): ):
if ( if (
isinstance(node.inputs[0], T.TensorConstant) isinstance(node.inputs[0], T.TensorConstant) and
and T.extract_constant(node.inputs[0]) == 0 T.extract_constant(node.inputs[0]) == 0
): ):
return [node.inputs[1]] return [node.inputs[1]]
if ( if (
isinstance(node.inputs[1], T.TensorConstant) isinstance(node.inputs[1], T.TensorConstant) and
and T.extract_constant(node.inputs[1]) == 0 T.extract_constant(node.inputs[1]) == 0
): ):
return [node.inputs[0]] return [node.inputs[0]]
if ( if (
isinstance(node.inputs[0], T.TensorConstant) isinstance(node.inputs[0], T.TensorConstant) and
and T.extract_constant(node.inputs[0]) == 1 T.extract_constant(node.inputs[0]) == 1
): ):
return ones_like(node, 1) return ones_like(node, 1)
if ( if (
isinstance(node.inputs[1], T.TensorConstant) isinstance(node.inputs[1], T.TensorConstant) and
and T.extract_constant(node.inputs[1]) == 1 T.extract_constant(node.inputs[1]) == 1
): ):
return ones_like(node, 0) return ones_like(node, 0)
elif ( elif (
isinstance(node.op.scalar_op, scalar.XOR) isinstance(node.op.scalar_op, scalar.XOR) and
and len(node.inputs) == 2 len(node.inputs) == 2
): ):
if node.inputs[0] == node.inputs[1]: if node.inputs[0] == node.inputs[1]:
return zeros_like(node, 0) return zeros_like(node, 0)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论