提交 f7d9a99f authored 作者: Dustin Webb's avatar Dustin Webb

Made robust to stack manipulations and added test.

上级 ac0e0430
...@@ -5,13 +5,15 @@ __docformat__ = "restructuredtext en" ...@@ -5,13 +5,15 @@ __docformat__ = "restructuredtext en"
import logging import logging
_logger = logging.getLogger('theano.compile.function') _logger = logging.getLogger('theano.compile.function')
import traceback as tb
import re
from theano.compile.io import In from theano.compile.io import In
from theano.compile.function_module import orig_function from theano.compile.function_module import orig_function
from theano.compile.pfunc import pfunc from theano.compile.pfunc import pfunc
from numpy import any # to work in python 2.4 from numpy import any # to work in python 2.4
import warnings import warnings
from theano import gof from theano import gof
import traceback as tb
def function(inputs, outputs=None, mode=None, updates=None, givens=None, def function(inputs, outputs=None, mode=None, updates=None, givens=None,
no_default_updates=False, accept_inplace=False, name=None, no_default_updates=False, accept_inplace=False, name=None,
...@@ -161,12 +163,21 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None, ...@@ -161,12 +163,21 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
""" """
if name is None: if name is None:
call_info = None # Determine possible file names
for elem in tb.extract_stack(): source_file = re.sub('\.pyc?', '.py', __file__)
if elem[2] != '<module>': compiled_file = source_file + 'c'
call_info = elem
# Find call to function and step back up the stack one step
stack = tb.extract_stack()
idx = None
for i, elem in enumerate(stack):
if elem[0] == source_file or elem[0] == compiled_file:
idx = i - 1
break break
if call_info is not None:
# Set the name
if idx is not None:
call_info = stack[idx]
name = call_info[0] + ':' + str(call_info[1]) name = call_info[0] + ':' + str(call_info[1])
if updates is None: if updates is None:
......
import unittest
import os
import re
import theano
from theano import tensor
class FunctionName(unittest.TestCase):
def test_function_name(self):
x = tensor.vector('x')
func = theano.function([x], x + 1.)
regex = re.compile(os.path.basename('.*test_function_name.pyc?:13'))
assert(regex.match(func.name) is not None)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论