提交 9078756f authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Clean up comments in basic_opt and math_opt

上级 ecf0d14a
""" Tensor optimizations addressing the ops in basic.py.""" """ Tensor optimizations addressing the ops in basic.py."""
# TODO: intelligent merge for mul/add
# TODO: 0*x -> 0
import logging import logging
import sys import sys
...@@ -588,8 +586,9 @@ def register_specialize_device(lopt, *tags, **kwargs): ...@@ -588,8 +586,9 @@ def register_specialize_device(lopt, *tags, **kwargs):
def apply_local_dimshuffle_lift(fgraph, var): def apply_local_dimshuffle_lift(fgraph, var):
# return var """
# lift recursively lift recursively
"""
if not var.owner: if not var.owner:
return var return var
new = local_dimshuffle_lift.transform(fgraph, var.owner) new = local_dimshuffle_lift.transform(fgraph, var.owner)
...@@ -598,10 +597,12 @@ def apply_local_dimshuffle_lift(fgraph, var): ...@@ -598,10 +597,12 @@ def apply_local_dimshuffle_lift(fgraph, var):
return var return var
# Checks for two types of useless dimshuffles:
# 1 - dimshuffle all dimensions in order.
# 2 - dimshuffle a broadcastable dimension.
def is_dimshuffle_useless(new_order, input): def is_dimshuffle_useless(new_order, input):
"""
Checks for two types of useless dimshuffles:
1 - dimshuffle all dimensions in order.
2 - dimshuffle a broadcastable dimension.
"""
is_useless = True is_useless = True
if len(new_order) == input.type.ndim: if len(new_order) == input.type.ndim:
all_broadcastable_dims = [ all_broadcastable_dims = [
...@@ -707,11 +708,6 @@ def local_useless_dimshuffle_in_reshape(fgraph, node): ...@@ -707,11 +708,6 @@ def local_useless_dimshuffle_in_reshape(fgraph, node):
return [ret] return [ret]
######################
# Casting operations #
######################
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@local_optimizer([TensorFromScalar]) @local_optimizer([TensorFromScalar])
...@@ -740,9 +736,6 @@ def local_scalar_tensor_scalar(fgraph, node): ...@@ -740,9 +736,6 @@ def local_scalar_tensor_scalar(fgraph, node):
return [s] return [s]
#####################################
# ShapeFeature, Shape optimizations
#####################################
class MakeVectorPrinter: class MakeVectorPrinter:
def process(self, r, pstate): def process(self, r, pstate):
if r.owner is None: if r.owner is None:
...@@ -1844,7 +1837,6 @@ def local_canonicalize_alloc(fgraph, node): ...@@ -1844,7 +1837,6 @@ def local_canonicalize_alloc(fgraph, node):
return [DimShuffle(inner.type.broadcastable, dimshuffle_new_order)(inner)] return [DimShuffle(inner.type.broadcastable, dimshuffle_new_order)(inner)]
# Don't register by default.
@local_optimizer([AllocEmpty]) @local_optimizer([AllocEmpty])
def local_alloc_empty_to_zeros(fgraph, node): def local_alloc_empty_to_zeros(fgraph, node):
"""This convert AllocEmpty to Alloc of 0. """This convert AllocEmpty to Alloc of 0.
...@@ -1884,7 +1876,6 @@ def local_shape_to_shape_i(fgraph, node): ...@@ -1884,7 +1876,6 @@ def local_shape_to_shape_i(fgraph, node):
return [ret] return [ret]
# TODO: Not sure what type of node we are expecting here
@register_specialize @register_specialize
@register_canonicalize @register_canonicalize
@local_optimizer([Shape_i]) @local_optimizer([Shape_i])
...@@ -1907,9 +1898,6 @@ def local_track_shape_i(fgraph, node): ...@@ -1907,9 +1898,6 @@ def local_track_shape_i(fgraph, node):
return [shape_feature.shape_of[replacement][node.op.i]] return [shape_feature.shape_of[replacement][node.op.i]]
# TODO: the other optimization for and, or, xor, le and ge see ticket #496.
@register_useless @register_useless
@register_canonicalize("fast_compile") @register_canonicalize("fast_compile")
@register_specialize @register_specialize
...@@ -2160,7 +2148,6 @@ def local_remove_all_assert(fgraph, node): ...@@ -2160,7 +2148,6 @@ def local_remove_all_assert(fgraph, node):
return [node.inputs[0]] return [node.inputs[0]]
# Disabled by default
compile.optdb["canonicalize"].register( compile.optdb["canonicalize"].register(
"local_remove_all_assert", "local_remove_all_assert",
local_remove_all_assert, local_remove_all_assert,
...@@ -2580,12 +2567,14 @@ def local_useless_switch(fgraph, node): ...@@ -2580,12 +2567,14 @@ def local_useless_switch(fgraph, node):
return False return False
# Merge add/sub/mul/div/minimum/maximum/... of switches sharing the same
# condition, to enable further simplification of their branches
# Example: switch(c, a, b) + switch(c, x, y) -> switch(c, a+x, b+y)
@register_canonicalize @register_canonicalize
@local_optimizer([Elemwise]) @local_optimizer([Elemwise])
def local_merge_switch_same_cond(fgraph, node): def local_merge_switch_same_cond(fgraph, node):
"""
Merge add/sub/mul/div/minimum/maximum/... of switches sharing the same
condition, to enable further simplification of their branches
Example: switch(c, a, b) + switch(c, x, y) -> switch(c, a+x, b+y)
"""
# node must be binary elemwise or add or mul # node must be binary elemwise or add or mul
if not isinstance(node.op, Elemwise) or not isinstance( if not isinstance(node.op, Elemwise) or not isinstance(
node.op.scalar_op, (aes.BinaryScalarOp, aes.Add, aes.Mul) node.op.scalar_op, (aes.BinaryScalarOp, aes.Add, aes.Mul)
...@@ -2613,9 +2602,6 @@ def local_merge_switch_same_cond(fgraph, node): ...@@ -2613,9 +2602,6 @@ def local_merge_switch_same_cond(fgraph, node):
] ]
#############
# Tile Opts #
#############
@register_useless @register_useless
@register_canonicalize @register_canonicalize
@register_stabilize @register_stabilize
...@@ -2659,9 +2645,6 @@ def local_useless_tile(fgraph, node): ...@@ -2659,9 +2645,6 @@ def local_useless_tile(fgraph, node):
return return
##############
# Split Opts #
##############
@register_useless @register_useless
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
...@@ -2685,9 +2668,6 @@ def local_useless_split(fgraph, node): ...@@ -2685,9 +2668,6 @@ def local_useless_split(fgraph, node):
return [out2] return [out2]
################
# Flatten Opts #
################
@register_canonicalize @register_canonicalize
@register_stabilize @register_stabilize
@local_optimizer([Flatten]) @local_optimizer([Flatten])
...@@ -2721,11 +2701,6 @@ def local_flatten_lift(fgraph, node): ...@@ -2721,11 +2701,6 @@ def local_flatten_lift(fgraph, node):
return [e] return [e]
##################
# Reshape opts #
##################
def local_reshape_chain(op): def local_reshape_chain(op):
@local_optimizer([op]) @local_optimizer([op])
def f(fgraph, node): def f(fgraph, node):
...@@ -2865,6 +2840,7 @@ def local_useless_reshape(fgraph, node): ...@@ -2865,6 +2840,7 @@ def local_useless_reshape(fgraph, node):
# TODO later: if all the shapes except one match, we may want to # TODO later: if all the shapes except one match, we may want to
# consider it useless as well, like we do in the 1-dim case. # consider it useless as well, like we do in the 1-dim case.
return False
@register_canonicalize @register_canonicalize
...@@ -2956,10 +2932,6 @@ def local_reshape_lift(fgraph, node): ...@@ -2956,10 +2932,6 @@ def local_reshape_lift(fgraph, node):
return [re] return [re]
##################
# Middleman cuts #
##################
register_canonicalize(OpRemove(tensor_copy), name="remove_tensor_copy") register_canonicalize(OpRemove(tensor_copy), name="remove_tensor_copy")
...@@ -3429,13 +3401,6 @@ def local_useless_composite(fgraph, node): ...@@ -3429,13 +3401,6 @@ def local_useless_composite(fgraph, node):
return dict(zip([node.outputs[i] for i in idx], e)) return dict(zip([node.outputs[i] for i in idx], e))
# ############################
# # Remove consider_constant #
# ############################
# Although the ops ConsiderConstant, ZeroGrad and DisconnectedGrad
# just returns the input, it should be removed from the graph to
@register_canonicalize("fast_compile") @register_canonicalize("fast_compile")
@register_useless("fast_compile") @register_useless("fast_compile")
@local_optimizer(None) @local_optimizer(None)
......
""" Tensor optimizations addressing the ops in math.py.""" """ Tensor optimizations addressing the ops in math.py."""
# TODO: intelligent merge for mul/add
# TODO: 0*x -> 0
import itertools import itertools
import logging import logging
...@@ -2225,12 +2223,15 @@ def local_log1p(fgraph, node): ...@@ -2225,12 +2223,15 @@ def local_log1p(fgraph, node):
return [log1p(neg(other))] return [log1p(neg(other))]
# TODO: in canonicalize, change log10 and log2 -> log
@register_stabilize @register_stabilize
@register_specialize @register_specialize
@local_optimizer([log]) @local_optimizer([log])
def local_log_add_exp(fgraph, node): def local_log_add_exp(fgraph, node):
# log(exp(x)+exp(y)+exp(z)) = max + log(x-max, y-max, z-max) """
``log(exp(x)+exp(y)+exp(z)) = max + log(x-max, y-max, z-max)``
TODO: in canonicalize, change log10 and log2 -> log
"""
if node.op == log: if node.op == log:
z = node.inputs[0] z = node.inputs[0]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论