提交 643e11af authored 作者: Dustin Webb's avatar Dustin Webb 提交者: Amjad Almahairi

Added following optimizations:

T.and(x,1) -> x; T.and(x,0) -> zeros_like(x); T.or(x,0) -> x; T.or(x,1) -> ones_like(x); T.xor(x,x) -> zeros_like(x); T.le(x,x) -> ones_like(x); T.ge(x,x) -> ones_like(x)
上级 aa49be05
......@@ -1553,6 +1553,16 @@ def local_useless_elemwise(node):
"""
if isinstance(node.op, T.Elemwise):
def zeros_like(node, in_idx):
    # Build a tensor of zeros with the shape/broadcast pattern of the
    # chosen input: T.fill broadcasts the scalar constant against
    # node.inputs[in_idx], cast to the output's dtype.
    zero = T.constant(0.0, dtype=node.outputs[0].type.dtype)
    return [T.fill(node.inputs[in_idx], zero)]
def ones_like(node, in_idx):
    # Build a tensor of ones with the shape/broadcast pattern of the
    # chosen input: T.fill broadcasts the scalar constant against
    # node.inputs[in_idx], cast to the output's dtype.
    one = T.constant(1.0, dtype=node.outputs[0].type.dtype)
    return [T.fill(node.inputs[in_idx], one)]
if node.op.scalar_op == theano.scalar.eq and len(node.inputs) == 2:
if node.inputs[0] == node.inputs[1]:
# it is the same var in the graph. That will always be true
......@@ -1577,13 +1587,90 @@ def local_useless_elemwise(node):
elif node.op.scalar_op == theano.scalar.mul and len(node.inputs) == 1:
# No need to copy over any stack trace
return [node.inputs[0]]
elif node.op.scalar_op == theano.scalar.add and len(node.inputs) == 1:
# No need to copy over any stack trace
return [node.inputs[0]]
elif (node.op.scalar_op == theano.scalar.identity and
len(node.inputs) == 1):
# No need to copy over any stack trace
if (
node.op.scalar_op == theano.scalar.identity and
len(node.inputs) == 1
):
return [node.inputs[0]]
if (
isinstance(node.op.scalar_op, theano.scalar.basic.AND) and
len(node.inputs) == 2
):
if (
isinstance(node.inputs[0], T.TensorConstant) and
node.inputs[0].data == 1.0
):
return [node.inputs[1]]
if (
isinstance(node.inputs[1], T.TensorConstant) and
node.inputs[1].data == 1.0
):
return [node.inputs[0]]
if (
isinstance(node.inputs[0], T.TensorConstant) and
node.inputs[0].data == 0.0
):
return zeros_like(node, 1)
if (
isinstance(node.inputs[1], T.TensorConstant) and
node.inputs[1].data == 0.0
):
return zeros_like(node, 0)
if (
isinstance(node.op.scalar_op, theano.scalar.basic.OR) and
len(node.inputs) == 2
):
if (
isinstance(node.inputs[0], T.TensorConstant) and
node.inputs[0].data == 0.0
):
return [node.inputs[1]]
if (
isinstance(node.inputs[1], T.TensorConstant) and
node.inputs[1].data == 0.0
):
return [node.inputs[0]]
if (
isinstance(node.inputs[0], T.TensorConstant) and
node.inputs[0].data == 1.0
):
return ones_like(node, 1)
if (
isinstance(node.inputs[1], T.TensorConstant) and
node.inputs[1].data == 1.0
):
return ones_like(node, 0)
if (
isinstance(node.op.scalar_op, theano.scalar.basic.XOR) and
len(node.inputs) == 2
):
if node.inputs[0] == node.inputs[1]:
return zeros_like(node, 0)
if (
(
isinstance(node.op.scalar_op, theano.scalar.basic.LE) or
isinstance(node.op.scalar_op, theano.scalar.basic.GE)
) and
len(node.inputs) == 2
):
if node.inputs[0] == node.inputs[1]:
return ones_like(node, 0)
@register_specialize
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论