提交 8628ecf9 authored 作者: Robert P. Goldman's avatar Robert P. Goldman 提交者: Brandon T. Willard

Don't warn repeatedly for same un-optimize-able Op

Closes #23.
上级 1bf52120
......@@ -13,6 +13,7 @@ import sys
from io import StringIO
from itertools import chain
from itertools import product as itertools_product
from logging import Logger
from warnings import warn
import numpy as np
......@@ -32,28 +33,11 @@ from theano.gof import graph, ops_with_inner_function
from theano.gof.utils import MethodNotDefined
from theano.link.basic import Container, LocalLinker
from theano.link.utils import map_storage, raise_with_op
from theano.utils import difference, get_unbound_function
from theano.utils import NoDuplicateOptWarningFilter, difference, get_unbound_function
__docformat__ = "restructuredtext en"
_logger = logging.getLogger("theano.compile.debugmode")
# Filter to avoid duplicating optimization warnings
class NoDuplicateOptWarningFilter(logging.Filter):
prev_msgs = set()
def filter(self, record):
msg = record.getMessage()
if msg.startswith("Optimization Warning: "):
if msg in self.prev_msgs:
return False
else:
self.prev_msgs.add(msg)
return True
return True
_logger: Logger = logging.getLogger("theano.compile.debugmode")
_logger.addFilter(NoDuplicateOptWarningFilter())
......@@ -538,7 +522,8 @@ def debugprint(
if used_ids is None:
used_ids = dict()
def get_id_str(obj, get_printed=True):
def get_id_str(obj, get_printed=True) -> str:
id_str: str = ""
if obj in used_ids:
id_str = used_ids[obj]
elif obj == "output":
......
......@@ -102,12 +102,11 @@ from theano.tensor.type import (
values_eq_approx_remove_inf_nan,
values_eq_approx_remove_nan,
)
# import theano.tensor.basic as tt
from theano.utils import NoDuplicateOptWarningFilter
_logger = logging.getLogger("theano.tensor.opt")
_logger.addFilter(NoDuplicateOptWarningFilter())
def _fill_chain(new_out, orig_inputs):
......
......@@ -3,6 +3,7 @@
import hashlib
import inspect
import logging
import os
import struct
import subprocess
......@@ -25,6 +26,7 @@ __all__ = [
"output_subprocess_Popen",
"LOCAL_BITWIDTH",
"PYTHON_INT_BITWIDTH",
"NoDuplicateOptWarningFilter",
]
......@@ -374,3 +376,19 @@ def flatten(a):
return l
else:
return [a]
class NoDuplicateOptWarningFilter(logging.Filter):
"""Filter to avoid duplicating optimization warnings."""
prev_msgs = set()
def filter(self, record):
msg = record.getMessage()
if msg.startswith("Optimization Warning: "):
if msg in self.prev_msgs:
return False
else:
self.prev_msgs.add(msg)
return True
return True
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论