提交 396e2f1e authored 作者: Amjad Almahairi's avatar Amjad Almahairi

minor modifs

上级 b65abe5c
...@@ -3296,40 +3296,21 @@ def local_useless_switch(node): ...@@ -3296,40 +3296,21 @@ def local_useless_switch(node):
# Elemwise{switch}(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X) # Elemwise{switch}(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X)
left = node.inputs[1] left = node.inputs[1]
right = node.inputs[2] right = node.inputs[2]
cond_var = node.inputs[0]
if (cond.owner and if (cond_var.owner and
isinstance(cond.owner.op, T.Elemwise) and isinstance(cond_var.owner.op, T.Elemwise) and
isinstance(cond.owner.op.scalar_op, scalar.LE) and isinstance(cond_var.owner.op.scalar_op, scalar.LE) and
cond.owner.inputs[0].owner and cond_var.owner.inputs[0].owner and
isinstance(cond.owner.inputs[0].owner.op, Shape_i) and isinstance(cond_var.owner.inputs[0].owner.op, Shape_i) and
T.extract_constant(cond.owner.inputs[1]) == 0 and T.extract_constant(cond_var.owner.inputs[1]) == 0 and
T.extract_constant(left) == 0 and T.extract_constant(left) == 0 and
right is cond.owner.inputs[0]): right is cond_var.owner.inputs[0]):
assert right.type == node.outputs[0].type assert right.type == node.outputs[0].type
return [right] return [right]
return False return False
return False return False
#@register_canonicalize
#@register_specialize
@gof.local_optimizer([Shape_i])
def local_shape_i_infered(node):
if not isinstance(node.op, Shape_i):
return
if not hasattr(node, 'fgraph'):
return
if not hasattr(node.fgraph, 'shape_feature'):
return
try:
shp = node.fgraph.shape_feature.shape_of[node.inputs[0]][node.op.i]
c = get_scalar_constant_value(shp)
import pdb;pdb.set_trace()
return [T.constant(c, dtype=node.outputs[0].dtype)]
except NotScalarConstantError:
pass
@register_canonicalize @register_canonicalize
@gof.local_optimizer([T.mul]) @gof.local_optimizer([T.mul])
def local_mul_switch_sink(node): def local_mul_switch_sink(node):
...@@ -4233,9 +4214,10 @@ def local_elemwise_sub_zeros(node): ...@@ -4233,9 +4214,10 @@ def local_elemwise_sub_zeros(node):
def local_useless_elemwise_comparison(node): def local_useless_elemwise_comparison(node):
"""... """...
:note: Those case appear in the graph generated around scan. This :note: These cases appear in the graph generated by scan.
don't remove much computation, but make the graph easier to These optimizations will not reduce computation,
read. but will make the graph easier to read.
# Comparing to itself is constant # Comparing to itself is constant
Elemwise[{LT,GT}](X, X) -> Elemwise[zeros](X) Elemwise[{LT,GT}](X, X) -> Elemwise[zeros](X)
Elemwise[{LE,GE}](X, X) -> Elemwise[ones](X) Elemwise[{LE,GE}](X, X) -> Elemwise[ones](X)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论