提交 fe81c12f authored 作者: Dustin Webb's avatar Dustin Webb 提交者: Amjad Almahairi

Minor improvements

上级 643e11af
...@@ -1603,26 +1603,27 @@ def local_useless_elemwise(node): ...@@ -1603,26 +1603,27 @@ def local_useless_elemwise(node):
): ):
if ( if (
isinstance(node.inputs[0], T.TensorConstant) and isinstance(node.inputs[0], T.TensorConstant)
node.inputs[0].data == 1.0 and T.extract_constant(node.inputs[0]) == 1
): ):
return [node.inputs[1]] return [node.inputs[1]]
if ( if (
isinstance(node.inputs[1], T.TensorConstant) and
node.inputs[1].data == 1.0 isinstance(node.inputs[1], T.TensorConstant)
and T.extract_constant(node.inputs[1]) == 1
): ):
return [node.inputs[0]] return [node.inputs[0]]
if ( if (
isinstance(node.inputs[0], T.TensorConstant) and isinstance(node.inputs[0], T.TensorConstant)
node.inputs[0].data == 0.0 and T.extract_constant(node.inputs[0]) == 0
): ):
return zeros_like(node, 1) return zeros_like(node, 1)
if ( if (
isinstance(node.inputs[1], T.TensorConstant) and isinstance(node.inputs[1], T.TensorConstant)
node.inputs[1].data == 0.0 and T.extract_constant(node.inputs[1]) == 0
): ):
return zeros_like(node, 0) return zeros_like(node, 0)
...@@ -1632,26 +1633,26 @@ def local_useless_elemwise(node): ...@@ -1632,26 +1633,26 @@ def local_useless_elemwise(node):
): ):
if ( if (
isinstance(node.inputs[0], T.TensorConstant) and isinstance(node.inputs[0], T.TensorConstant)
node.inputs[0].data == 0.0 and T.extract_constant(node.inputs[0]) == 0
): ):
return [node.inputs[1]] return [node.inputs[1]]
if ( if (
isinstance(node.inputs[1], T.TensorConstant) and isinstance(node.inputs[1], T.TensorConstant)
node.inputs[1].data == 0.0 and T.extract_constant(node.inputs[1]) == 0
): ):
return [node.inputs[0]] return [node.inputs[0]]
if ( if (
isinstance(node.inputs[0], T.TensorConstant) and isinstance(node.inputs[0], T.TensorConstant)
node.inputs[0].data == 1.0 and T.extract_constant(node.inputs[0]) == 1
): ):
return ones_like(node, 1) return ones_like(node, 1)
if ( if (
isinstance(node.inputs[1], T.TensorConstant) and isinstance(node.inputs[1], T.TensorConstant)
node.inputs[1].data == 1.0 and T.extract_constant(node.inputs[1]) == 1
): ):
return ones_like(node, 0) return ones_like(node, 0)
...@@ -1663,11 +1664,11 @@ def local_useless_elemwise(node): ...@@ -1663,11 +1664,11 @@ def local_useless_elemwise(node):
return zeros_like(node, 0) return zeros_like(node, 0)
if ( if (
( isinstance(
isinstance(node.op.scalar_op, theano.scalar.basic.LE) or node.op.scalar_op,
isinstance(node.op.scalar_op, theano.scalar.basic.GE) (theano.scalar.basic.LE, theano.scalar.basic.GE)
) and )
len(node.inputs) == 2 and len(node.inputs) == 2
): ):
if node.inputs[0] == node.inputs[1]: if node.inputs[0] == node.inputs[1]:
return ones_like(node, 0) return ones_like(node, 0)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论