提交 3128a44c authored 作者: James Bergstra's avatar James Bergstra

added local_mul_to_sqr optimization

上级 b7dda09a
...@@ -1332,6 +1332,18 @@ def local_pow_canonicalize(node): ...@@ -1332,6 +1332,18 @@ def local_pow_canonicalize(node):
return False return False
register_canonicalize(local_pow_canonicalize) register_canonicalize(local_pow_canonicalize)
@register_specialize
@gof.local_optimizer([T.mul])
def local_mul_to_sqr(node):
    """Rewrite ``x * x`` as ``sqr(x)``.

    The rewrite applies only when the node's op is ``T.mul``, it has
    exactly two inputs, and both inputs are the very same object.

    This is faster on the GPU when memory fetching is a big part of the
    computation time.

    Returns a one-element list holding ``T.sqr(x)`` when the node matches,
    and None otherwise.
    """
    # Only multiplication nodes are candidates.
    if not (node.op == T.mul):
        return None

    operands = node.inputs
    if len(operands) != 2:
        return None

    # Identity (not equality) check: both operands must be the same variable.
    lhs, rhs = operands
    if lhs is not rhs:
        return None

    return [T.sqr(lhs)]
@gof.local_optimizer([T.pow]) @gof.local_optimizer([T.pow])
def local_pow_specialize(node): def local_pow_specialize(node):
#here, we are past the point of canonicalization, so we don't want to put in un-necessary fills. #here, we are past the point of canonicalization, so we don't want to put in un-necessary fills.
...@@ -1370,7 +1382,9 @@ register_specialize(local_pow_specialize) ...@@ -1370,7 +1382,9 @@ register_specialize(local_pow_specialize)
@gof.local_optimizer([T.mul]) @gof.local_optimizer([T.mul])
def local_mul_specialize(node): def local_mul_specialize(node):
#here, we are past the point of canonicalization, so we don't want to put in un-necessary fills. """Remove special-case constants from mul arguments
"""
# here, we are past the point of canonicalization, so we don't want to put in un-necessary fills.
# #
# at this point [post canonicalize], mul() may have many inputs. # at this point [post canonicalize], mul() may have many inputs.
if node.op == T.mul: if node.op == T.mul:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论