提交 e7064712 authored 作者: Reyhane Askari's avatar Reyhane Askari

fix for optdb.register

上级 7c341d1d
...@@ -226,10 +226,10 @@ optdb.register('add_destroy_handler', AddDestroyHandler(), ...@@ -226,10 +226,10 @@ optdb.register('add_destroy_handler', AddDestroyHandler(),
optdb.register('merge3', gof.MergeOptimizer(), optdb.register('merge3', gof.MergeOptimizer(),
100, 'fast_run', 'merge') 100, 'fast_run', 'merge')
if theano.config.check_stack_trace in ['check_all', 'check_and_skip']: if theano.config.check_stack_trace in ['raise', 'warn', 'log']:
tags = ('fast_run', 'fast_compile') tags = ('fast_run', 'fast_compile')
if theano.config.check_stack_trace == 'not_checking': if theano.config.check_stack_trace == 'off':
tags = () tags = ()
optdb.register('CheckStackTrace', optdb.register('CheckStackTrace',
......
...@@ -3043,6 +3043,8 @@ def check_stack_trace(f_or_fgraph, ops_to_check='last', bug_print='raise'): ...@@ -3043,6 +3043,8 @@ def check_stack_trace(f_or_fgraph, ops_to_check='last', bug_print='raise'):
class CheckStrackTraceFeature(object): class CheckStrackTraceFeature(object):
def on_import(self, fgraph, node, reason): def on_import(self, fgraph, node, reason):
# In optdb we only register the CheckStackTraceOptimization when
# theano.config.check_stack_trace is not off but we also double check here.
if theano.config.check_stack_trace != 'off' and not check_stack_trace(fgraph, 'all'): if theano.config.check_stack_trace != 'off' and not check_stack_trace(fgraph, 'all'):
if theano.config.check_stack_trace == 'warn': if theano.config.check_stack_trace == 'warn':
warnings.warn( warnings.warn(
...@@ -3060,7 +3062,7 @@ class CheckStrackTraceFeature(object): ...@@ -3060,7 +3062,7 @@ class CheckStrackTraceFeature(object):
class CheckStackTraceOptimization(Optimizer): class CheckStackTraceOptimization(Optimizer):
"""Optimizer that serves to add ShapeFeature as an fgraph feature.""" """Optimizer that serves to add CheckStackTraceOptimization as an fgraph feature."""
def add_requirements(self, fgraph): def add_requirements(self, fgraph):
if not hasattr(fgraph, 'CheckStrackTraceFeature'): if not hasattr(fgraph, 'CheckStrackTraceFeature'):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论