提交 a63cd50f authored 作者: Amjad Almahairi's avatar Amjad Almahairi

some refactoring

上级 1ddccd76
...@@ -1596,76 +1596,47 @@ def local_useless_elemwise(node): ...@@ -1596,76 +1596,47 @@ def local_useless_elemwise(node):
elif node.op.scalar_op == theano.scalar.add and len(node.inputs) == 1: elif node.op.scalar_op == theano.scalar.add and len(node.inputs) == 1:
# No need to copy over any stack trace # No need to copy over any stack trace
return [node.inputs[0]] return [node.inputs[0]]
elif node.op.scalar_op == theano.scalar.identity and \
elif ( len(node.inputs) == 1:
node.op.scalar_op == theano.scalar.identity and
len(node.inputs) == 1
):
return [node.inputs[0]] return [node.inputs[0]]
elif ( elif isinstance(node.op.scalar_op, scalar.AND) and \
isinstance(node.op.scalar_op, scalar.AND) and len(node.inputs) == 2:
len(node.inputs) == 2
):
if (
isinstance(node.inputs[0], T.TensorConstant) and
T.extract_constant(node.inputs[0]) == 1
):
return [node.inputs[1]]
if (
isinstance(node.inputs[1], T.TensorConstant) and
T.extract_constant(node.inputs[1]) == 1
):
return [node.inputs[0]]
if (
isinstance(node.inputs[0], T.TensorConstant) and
T.extract_constant(node.inputs[0]) == 0
):
return zeros_like(node, 1)
if (
isinstance(node.inputs[1], T.TensorConstant) and
T.extract_constant(node.inputs[1]) == 0
):
return zeros_like(node, 0)
elif (
isinstance(node.op.scalar_op, scalar.OR) and
len(node.inputs) == 2
):
if ( if isinstance(node.inputs[0], T.TensorConstant):
isinstance(node.inputs[0], T.TensorConstant) and const_val = T.extract_constant(node.inputs[0])
T.extract_constant(node.inputs[0]) == 0 if const_val == 1:
): return [node.inputs[1]]
return [node.inputs[1]] elif const_val == 0:
return zeros_like(node, 1)
if (
isinstance(node.inputs[1], T.TensorConstant) and if isinstance(node.inputs[1], T.TensorConstant):
T.extract_constant(node.inputs[1]) == 0 const_val = T.extract_constant(node.inputs[1])
): if const_val == 1:
return [node.inputs[0]] return [node.inputs[0]]
elif const_val == 0:
if ( return zeros_like(node, 0)
isinstance(node.inputs[0], T.TensorConstant) and
T.extract_constant(node.inputs[0]) == 1 elif isinstance(node.op.scalar_op, scalar.OR) and \
): len(node.inputs) == 2:
return ones_like(node, 1)
if isinstance(node.inputs[0], T.TensorConstant):
if ( const_val = T.extract_constant(node.inputs[0])
isinstance(node.inputs[1], T.TensorConstant) and if const_val == 0:
T.extract_constant(node.inputs[1]) == 1 return [node.inputs[1]]
): elif const_val == 1:
return ones_like(node, 0) return ones_like(node, 1)
elif ( if isinstance(node.inputs[1], T.TensorConstant):
isinstance(node.op.scalar_op, scalar.XOR) and const_val = T.extract_constant(node.inputs[1])
len(node.inputs) == 2 if const_val == 0:
): return [node.inputs[0]]
if node.inputs[0] == node.inputs[1]: if const_val == 1:
return ones_like(node, 0)
elif isinstance(node.op.scalar_op, scalar.XOR) and \
len(node.inputs) == 2:
if node.inputs[0] is node.inputs[1]:
return zeros_like(node, 0) return zeros_like(node, 0)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论