提交 9078756f authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Clean up comments in basic_opt and math_opt

上级 ecf0d14a
""" Tensor optimizations addressing the ops in basic.py."""
# TODO: intelligent merge for mul/add
# TODO: 0*x -> 0
import logging
import sys
......@@ -588,8 +586,9 @@ def register_specialize_device(lopt, *tags, **kwargs):
def apply_local_dimshuffle_lift(fgraph, var):
# return var
# lift recursively
"""
lift recursively
"""
if not var.owner:
return var
new = local_dimshuffle_lift.transform(fgraph, var.owner)
......@@ -598,10 +597,12 @@ def apply_local_dimshuffle_lift(fgraph, var):
return var
# Checks for two types of useless dimshuffles:
# 1 - dimshuffle all dimensions in order.
# 2 - dimshuffle a broadcastable dimension.
def is_dimshuffle_useless(new_order, input):
"""
Checks for two types of useless dimshuffles:
1 - dimshuffle all dimensions in order.
2 - dimshuffle a broadcastable dimension.
"""
is_useless = True
if len(new_order) == input.type.ndim:
all_broadcastable_dims = [
......@@ -707,11 +708,6 @@ def local_useless_dimshuffle_in_reshape(fgraph, node):
return [ret]
######################
# Casting operations #
######################
@register_canonicalize
@register_specialize
@local_optimizer([TensorFromScalar])
......@@ -740,9 +736,6 @@ def local_scalar_tensor_scalar(fgraph, node):
return [s]
#####################################
# ShapeFeature, Shape optimizations
#####################################
class MakeVectorPrinter:
def process(self, r, pstate):
if r.owner is None:
......@@ -1844,7 +1837,6 @@ def local_canonicalize_alloc(fgraph, node):
return [DimShuffle(inner.type.broadcastable, dimshuffle_new_order)(inner)]
# Don't register by default.
@local_optimizer([AllocEmpty])
def local_alloc_empty_to_zeros(fgraph, node):
"""This convert AllocEmpty to Alloc of 0.
......@@ -1884,7 +1876,6 @@ def local_shape_to_shape_i(fgraph, node):
return [ret]
# TODO: Not sure what type of node we are expecting here
@register_specialize
@register_canonicalize
@local_optimizer([Shape_i])
......@@ -1907,9 +1898,6 @@ def local_track_shape_i(fgraph, node):
return [shape_feature.shape_of[replacement][node.op.i]]
# TODO: the other optimization for and, or, xor, le and ge see ticket #496.
@register_useless
@register_canonicalize("fast_compile")
@register_specialize
......@@ -2160,7 +2148,6 @@ def local_remove_all_assert(fgraph, node):
return [node.inputs[0]]
# Disabled by default
compile.optdb["canonicalize"].register(
"local_remove_all_assert",
local_remove_all_assert,
......@@ -2580,12 +2567,14 @@ def local_useless_switch(fgraph, node):
return False
# Merge add/sub/mul/div/minimum/maximum/... of switches sharing the same
# condition, to enable further simplification of their branches
# Example: switch(c, a, b) + switch(c, x, y) -> switch(c, a+x, b+y)
@register_canonicalize
@local_optimizer([Elemwise])
def local_merge_switch_same_cond(fgraph, node):
"""
Merge add/sub/mul/div/minimum/maximum/... of switches sharing the same
condition, to enable further simplification of their branches
Example: switch(c, a, b) + switch(c, x, y) -> switch(c, a+x, b+y)
"""
# node must be binary elemwise or add or mul
if not isinstance(node.op, Elemwise) or not isinstance(
node.op.scalar_op, (aes.BinaryScalarOp, aes.Add, aes.Mul)
......@@ -2613,9 +2602,6 @@ def local_merge_switch_same_cond(fgraph, node):
]
#############
# Tile Opts #
#############
@register_useless
@register_canonicalize
@register_stabilize
......@@ -2659,9 +2645,6 @@ def local_useless_tile(fgraph, node):
return
##############
# Split Opts #
##############
@register_useless
@register_canonicalize
@register_specialize
......@@ -2685,9 +2668,6 @@ def local_useless_split(fgraph, node):
return [out2]
################
# Flatten Opts #
################
@register_canonicalize
@register_stabilize
@local_optimizer([Flatten])
......@@ -2721,11 +2701,6 @@ def local_flatten_lift(fgraph, node):
return [e]
##################
# Reshape opts #
##################
def local_reshape_chain(op):
@local_optimizer([op])
def f(fgraph, node):
......@@ -2865,6 +2840,7 @@ def local_useless_reshape(fgraph, node):
# TODO later: if all the shapes except one match, we may want to
# consider it useless as well, like we do in the 1-dim case.
return False
@register_canonicalize
......@@ -2956,10 +2932,6 @@ def local_reshape_lift(fgraph, node):
return [re]
##################
# Middleman cuts #
##################
register_canonicalize(OpRemove(tensor_copy), name="remove_tensor_copy")
......@@ -3429,13 +3401,6 @@ def local_useless_composite(fgraph, node):
return dict(zip([node.outputs[i] for i in idx], e))
# ############################
# # Remove consider_constant #
# ############################
# Although the ops ConsiderConstant, ZeroGrad and DisconnectedGrad
# just returns the input, it should be removed from the graph to
@register_canonicalize("fast_compile")
@register_useless("fast_compile")
@local_optimizer(None)
......
""" Tensor optimizations addressing the ops in math.py."""
# TODO: intelligent merge for mul/add
# TODO: 0*x -> 0
import itertools
import logging
......@@ -2225,12 +2223,15 @@ def local_log1p(fgraph, node):
return [log1p(neg(other))]
# TODO: in canonicalize, change log10 and log2 -> log
@register_stabilize
@register_specialize
@local_optimizer([log])
def local_log_add_exp(fgraph, node):
# log(exp(x)+exp(y)+exp(z)) = max + log(x-max, y-max, z-max)
"""
``log(exp(x)+exp(y)+exp(z)) = max + log(x-max, y-max, z-max)``
TODO: in canonicalize, change log10 and log2 -> log
"""
if node.op == log:
z = node.inputs[0]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论