提交 edf974b4 authored 作者: Frederic's avatar Frederic

Allow opt registered in name EquilibriumDB to do not be enable by the…

Allow opt registered in name EquilibriumDB to do not be enable by the EquilibrumDB name and use that for local_remove_all_assert opt.
上级 7c71f3ad
......@@ -32,7 +32,21 @@ class DB(object):
self.name = None # will be reset by register
#(via obj.name by the thing doing the registering)
def register(self, name, obj, *tags):
def register(self, name, obj, *tags, **kwargs):
"""
:param name: name of the optimizer.
:param obj: the optimizer to register.
:param tags: tag name that allow to select the optimizer.
:param kwargs: If non empty, should contain
only use_db_name_as_tag=False.
By default, all optimizations registered in EquilibriumDB
are selected when the EquilibriumDB name is used as a
tag. We do not want this behavior for some optimizer like
local_remove_all_assert. use_db_name_as_tag=False remove
that behavior. This mean only the optimizer name and the
tags specified will enable that optimization.
"""
# N.B. obj is not an instance of class Optimizer.
# It is an instance of a DB.In the tests for example,
# this is not always the case.
......@@ -42,9 +56,12 @@ class DB(object):
raise ValueError('The name of the object cannot be an existing'
' tag or the name of another existing object.',
obj, name)
if self.name is not None:
tags = tags + (self.name,)
if kwargs:
assert "use_db_name_as_tag" in kwargs
assert kwargs["use_db_name_as_tag"] is False
else:
if self.name is not None:
tags = tags + (self.name,)
obj.name = name
# This restriction is there because in many place we suppose that
# something in the DB is there only once.
......@@ -155,6 +172,10 @@ class Query(object):
if isinstance(self.exclude, (list, tuple)):
self.exclude = OrderedSet(self.exclude)
def __str__(self):
return "Query{inc=%s,ex=%s,require=%s,subquery=%s,position_cutoff=%d}" % (
self.include, self.exclude, self.require, self.subquery, self.position_cutoff)
#add all opt with this tag
def including(self, *tags):
return Query(self.include.union(tags),
......
......@@ -1596,8 +1596,8 @@ def local_remove_all_assert(node):
return [node.inputs[0]]
# Disabled by default
compile.optdb['canonicalize'].register('local_remove_all_assert',
local_remove_all_assert)
local_remove_all_assert,
use_db_name_as_tag=False)
@register_specialize
@gof.local_optimizer([T.Elemwise])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论