提交 07061384 authored 作者: James Bergstra's avatar James Bergstra

merge

......@@ -10,7 +10,8 @@ from function_module import orig_function
from pfunc import pfunc
from numpy import any #for to work in python 2.4
def function(inputs, outputs=None, mode=None, updates=[], givens=[], accept_inplace=False, name=None):
def function(inputs, outputs=None, mode=None, updates=[], givens=[],
no_default_updates=False, accept_inplace=False, name=None):
"""
Return a callable object that will calculate `outputs` from `inputs`.
......@@ -31,7 +32,12 @@ def function(inputs, outputs=None, mode=None, updates=[], givens=[], accept_inpl
and Var2 in each pair must have the same Type.
:param givens: specific substitutions to make in the computation graph (Var2 replaces
Var1).
Var1).
:type no_default_updates: either bool or list of Variables
:param no_default_updates: if True, do not perform any automatic update on Variables.
If False (default), perform them all. Else, perform automatic updates on all Variables
that are neither in "updates" nor in "no_default_updates".
:param name: an optional name for this function. The profile mode will print the time spent in this function.
......@@ -65,4 +71,5 @@ def function(inputs, outputs=None, mode=None, updates=[], givens=[], accept_inpl
mode=mode,
updates=updates,
givens=givens,
no_default_updates=no_default_updates,
accept_inplace=accept_inplace,name=name)
差异被折叠。
......@@ -28,6 +28,11 @@ class SharedVariable(Variable):
:type: `Container`
"""
# default_update
# If this member is present, its value will be used as the "update" for
# this Variable, unless another update value has been passed to "function",
# or the "no_default_updates" list passed to "function" contains it.
def __init__(self, name, type, value, strict, container=None):
"""
:param name: The name for this variable (see `Variable`).
......
......@@ -562,7 +562,7 @@ using namespace std;
if not d["type"]=="double":d["gemm"]='sgemm_'
if self.imshp != self.imshp_logical or self.kshp != self.kshp_logical:
if verbose:
if self.verbose:
print "return imshp!=imshp_logical or self.kshp != self.kshp_logical shape version"
return _conv_op_code_a % d
......
......@@ -228,7 +228,7 @@ class T_CrossentropyCategorical1Hot(unittest.TestCase):
# TODO: add the optimization in FAST_COMPILE?
# In the mean time, run it as 'FAST_RUN' instead
mode = theano.compile.mode.get_default_mode()
if mode == 'FAST_COMPILE':
if mode == theano.compile.mode.get_mode('FAST_COMPILE'):
mode = 'FAST_RUN'
rng = numpy.random.RandomState(utt.fetch_seed())
......@@ -327,7 +327,7 @@ class T_CrossentropyCategorical1Hot(unittest.TestCase):
# TODO: add the optimization in FAST_COMPILE?
# In the mean time, run it as 'FAST_RUN' instead
mode = theano.compile.mode.get_default_mode()
if mode == 'FAST_COMPILE':
if mode == theano.compile.mode.get_mode('FAST_COMPILE'):
mode = 'FAST_RUN'
rng = numpy.random.RandomState(utt.fetch_seed())
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论