提交 8628ecf9 authored 作者: Robert P. Goldman's avatar Robert P. Goldman 提交者: Brandon T. Willard

Don't warn repeatedly for same un-optimize-able Op

Closes #23.
上级 1bf52120
...@@ -13,6 +13,7 @@ import sys ...@@ -13,6 +13,7 @@ import sys
from io import StringIO from io import StringIO
from itertools import chain from itertools import chain
from itertools import product as itertools_product from itertools import product as itertools_product
from logging import Logger
from warnings import warn from warnings import warn
import numpy as np import numpy as np
...@@ -32,28 +33,11 @@ from theano.gof import graph, ops_with_inner_function ...@@ -32,28 +33,11 @@ from theano.gof import graph, ops_with_inner_function
from theano.gof.utils import MethodNotDefined from theano.gof.utils import MethodNotDefined
from theano.link.basic import Container, LocalLinker from theano.link.basic import Container, LocalLinker
from theano.link.utils import map_storage, raise_with_op from theano.link.utils import map_storage, raise_with_op
from theano.utils import difference, get_unbound_function from theano.utils import NoDuplicateOptWarningFilter, difference, get_unbound_function
__docformat__ = "restructuredtext en" __docformat__ = "restructuredtext en"
_logger = logging.getLogger("theano.compile.debugmode") _logger: Logger = logging.getLogger("theano.compile.debugmode")
# Filter to avoid duplicating optimization warnings
class NoDuplicateOptWarningFilter(logging.Filter):
prev_msgs = set()
def filter(self, record):
msg = record.getMessage()
if msg.startswith("Optimization Warning: "):
if msg in self.prev_msgs:
return False
else:
self.prev_msgs.add(msg)
return True
return True
_logger.addFilter(NoDuplicateOptWarningFilter()) _logger.addFilter(NoDuplicateOptWarningFilter())
...@@ -538,7 +522,8 @@ def debugprint( ...@@ -538,7 +522,8 @@ def debugprint(
if used_ids is None: if used_ids is None:
used_ids = dict() used_ids = dict()
def get_id_str(obj, get_printed=True): def get_id_str(obj, get_printed=True) -> str:
id_str: str = ""
if obj in used_ids: if obj in used_ids:
id_str = used_ids[obj] id_str = used_ids[obj]
elif obj == "output": elif obj == "output":
......
...@@ -102,12 +102,11 @@ from theano.tensor.type import ( ...@@ -102,12 +102,11 @@ from theano.tensor.type import (
values_eq_approx_remove_inf_nan, values_eq_approx_remove_inf_nan,
values_eq_approx_remove_nan, values_eq_approx_remove_nan,
) )
from theano.utils import NoDuplicateOptWarningFilter
# import theano.tensor.basic as tt
_logger = logging.getLogger("theano.tensor.opt") _logger = logging.getLogger("theano.tensor.opt")
_logger.addFilter(NoDuplicateOptWarningFilter())
def _fill_chain(new_out, orig_inputs): def _fill_chain(new_out, orig_inputs):
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import hashlib import hashlib
import inspect import inspect
import logging
import os import os
import struct import struct
import subprocess import subprocess
...@@ -25,6 +26,7 @@ __all__ = [ ...@@ -25,6 +26,7 @@ __all__ = [
"output_subprocess_Popen", "output_subprocess_Popen",
"LOCAL_BITWIDTH", "LOCAL_BITWIDTH",
"PYTHON_INT_BITWIDTH", "PYTHON_INT_BITWIDTH",
"NoDuplicateOptWarningFilter",
] ]
...@@ -374,3 +376,19 @@ def flatten(a): ...@@ -374,3 +376,19 @@ def flatten(a):
return l return l
else: else:
return [a] return [a]
class NoDuplicateOptWarningFilter(logging.Filter):
"""Filter to avoid duplicating optimization warnings."""
prev_msgs = set()
def filter(self, record):
msg = record.getMessage()
if msg.startswith("Optimization Warning: "):
if msg in self.prev_msgs:
return False
else:
self.prev_msgs.add(msg)
return True
return True
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论