提交 25a58173 authored 作者: Dustin Webb's avatar Dustin Webb 提交者: Amjad Almahairi

Added some documentation and a minor optimization.

上级 7b924d44
...@@ -1550,16 +1550,23 @@ def local_useless_elemwise(node): ...@@ -1550,16 +1550,23 @@ def local_useless_elemwise(node):
mul(x) -> x mul(x) -> x
add(x) -> x add(x) -> x
identity(x) -> x identity(x) -> x
and(x,1) -> x
and(x,0) -> zeros_like(x)
or(x,0) -> x
or(x,1) -> ones_like(x)
xor(x,x) -> zeros_like(x)
le(x,x) -> ones_like(x)
ge(x,x) -> ones_like(x)
""" """
if isinstance(node.op, T.Elemwise): if isinstance(node.op, T.Elemwise):
def zeros_like(node, in_idx): def zeros_like(node, in_idx):
#it is the same var in the graph. That will always be true # it is the same var in the graph. That will always be true
return [T.fill(node.inputs[in_idx], return [T.fill(node.inputs[in_idx],
T.constant(0.0, dtype=node.outputs[0].type.dtype))] T.constant(0.0, dtype=node.outputs[0].type.dtype))]
def ones_like(node, in_idx): def ones_like(node, in_idx):
#it is the same var in the graph. That will always be true # it is the same var in the graph. That will always be true
return [T.fill(node.inputs[in_idx], return [T.fill(node.inputs[in_idx],
T.constant(1.0, dtype=node.outputs[0].type.dtype))] T.constant(1.0, dtype=node.outputs[0].type.dtype))]
...@@ -1591,13 +1598,13 @@ def local_useless_elemwise(node): ...@@ -1591,13 +1598,13 @@ def local_useless_elemwise(node):
elif node.op.scalar_op == theano.scalar.add and len(node.inputs) == 1: elif node.op.scalar_op == theano.scalar.add and len(node.inputs) == 1:
# No need to copy over any stack trace # No need to copy over any stack trace
return [node.inputs[0]] return [node.inputs[0]]
if (
node.op.scalar_op == theano.scalar.identity and elif (
len(node.inputs) == 1 node.op.scalar_op == theano.scalar.identity
and len(node.inputs) == 1
): ):
return [node.inputs[0]] return [node.inputs[0]]
elif (
if (
isinstance(node.op.scalar_op, scalar.AND) isinstance(node.op.scalar_op, scalar.AND)
and len(node.inputs) == 2 and len(node.inputs) == 2
): ):
...@@ -1626,8 +1633,7 @@ def local_useless_elemwise(node): ...@@ -1626,8 +1633,7 @@ def local_useless_elemwise(node):
and T.extract_constant(node.inputs[1]) == 0 and T.extract_constant(node.inputs[1]) == 0
): ):
return zeros_like(node, 0) return zeros_like(node, 0)
elif (
if (
isinstance(node.op.scalar_op, scalar.OR) isinstance(node.op.scalar_op, scalar.OR)
and len(node.inputs) == 2 and len(node.inputs) == 2
): ):
...@@ -1655,8 +1661,7 @@ def local_useless_elemwise(node): ...@@ -1655,8 +1661,7 @@ def local_useless_elemwise(node):
and T.extract_constant(node.inputs[1]) == 1 and T.extract_constant(node.inputs[1]) == 1
): ):
return ones_like(node, 0) return ones_like(node, 0)
elif (
if (
isinstance(node.op.scalar_op, scalar.XOR) isinstance(node.op.scalar_op, scalar.XOR)
and len(node.inputs) == 2 and len(node.inputs) == 2
): ):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论