提交 96220be7 authored 作者: nouiz's avatar nouiz

Merge pull request #421 from goodfeli/fix_defaults

Fix defaults
...@@ -16,7 +16,6 @@ from theano import gof ...@@ -16,7 +16,6 @@ from theano import gof
from theano.gof.python25 import partial from theano.gof.python25 import partial
import mode as mode_module import mode as mode_module
from io import In, SymbolicInput, SymbolicInputKit, SymbolicOutput from io import In, SymbolicInput, SymbolicInputKit, SymbolicOutput
from theano.configdefaults import config
import logging import logging
_logger = logging.getLogger('theano.compile.function_module') _logger = logging.getLogger('theano.compile.function_module')
...@@ -346,6 +345,7 @@ class Function(object): ...@@ -346,6 +345,7 @@ class Function(object):
Initialize attributes. create finder, inv_finder. Initialize attributes. create finder, inv_finder.
""" """
self.fn = fn self.fn = fn
self.input_storage = input_storage self.input_storage = input_storage
self.output_storage = output_storage self.output_storage = output_storage
...@@ -1082,6 +1082,7 @@ class FunctionMaker(object): ...@@ -1082,6 +1082,7 @@ class FunctionMaker(object):
acts as initialization. acts as initialization.
trustme -> disables some exceptions, used internally trustme -> disables some exceptions, used internally
""" """
if input_storage is None: if input_storage is None:
input_storage = [None]*len(self.inputs) input_storage = [None]*len(self.inputs)
input_storage_lists = [] # list of independent one-element lists, will be passed to the linker input_storage_lists = [] # list of independent one-element lists, will be passed to the linker
...@@ -1090,6 +1091,7 @@ class FunctionMaker(object): ...@@ -1090,6 +1091,7 @@ class FunctionMaker(object):
# The following loop is to fill in the input_storage_lists and defaults lists. # The following loop is to fill in the input_storage_lists and defaults lists.
assert len(self.indices) == len(input_storage) assert len(self.indices) == len(input_storage)
for i, ((input, indices, subinputs), input_storage_i) in enumerate(zip(self.indices, input_storage)): for i, ((input, indices, subinputs), input_storage_i) in enumerate(zip(self.indices, input_storage)):
# Replace any default value given as a variable by its container. # Replace any default value given as a variable by its container.
# Note that this makes sense only in the context of shared variables, # Note that this makes sense only in the context of shared variables,
# but for now we avoid dealing directly with them to avoid dependency # but for now we avoid dealing directly with them to avoid dependency
...@@ -1104,13 +1106,38 @@ class FunctionMaker(object): ...@@ -1104,13 +1106,38 @@ class FunctionMaker(object):
if indices is not None: if indices is not None:
raise TypeError("Cannot take a Container instance as default for a SymbolicInputKit.") raise TypeError("Cannot take a Container instance as default for a SymbolicInputKit.")
input_storage_lists.append(input_storage_i.storage) input_storage_lists.append(input_storage_i.storage)
defaults.append((self.required[i],
self.refeed[i], storage = input_storage[i].storage[0]
input_storage_i.storage[0]))
else: else:
# Normal case: one new, independent storage unit # Normal case: one new, independent storage unit
input_storage_lists.append([input_storage_i]) input_storage_lists.append([input_storage_i])
defaults.append((self.required[i], self.refeed[i], input_storage_i))
storage = input_storage_i
required = self.required[i]
refeed = self.refeed[i]
#sanity check-- if an input is required it should not need to be refed
assert not (required and refeed)
#shared variables need neither be input by the user nor refed
if input.shared:
assert not required
assert not refeed
storage = None
#if an input is required, it never need be refed
if required:
storage = None
#make sure that we only store a value if we actually need it
if storage is not None:
assert refeed or not required
defaults.append((required,
refeed,
storage))
# Get a function instance # Get a function instance
...@@ -1125,6 +1152,7 @@ class FunctionMaker(object): ...@@ -1125,6 +1152,7 @@ class FunctionMaker(object):
self.profile.linker_time += linker_time self.profile.linker_time += linker_time
_fn.time_thunks = self.profile.flag_time_thunks _fn.time_thunks = self.profile.flag_time_thunks
fn = self.function_builder(_fn, _i, _o, self.indices, self.outputs, fn = self.function_builder(_fn, _i, _o, self.indices, self.outputs,
defaults, self.unpack_single, self.return_none, self) defaults, self.unpack_single, self.return_none, self)
fn.profile = self.profile fn.profile = self.profile
......
...@@ -194,7 +194,12 @@ class In(SymbolicInput): ...@@ -194,7 +194,12 @@ class In(SymbolicInput):
# try to keep it synchronized. # try to keep it synchronized.
def __init__(self, variable, name=None, value=None, update=None, def __init__(self, variable, name=None, value=None, update=None,
mutable=None, strict=False, allow_downcast=None, autoname=True, mutable=None, strict=False, allow_downcast=None, autoname=True,
implicit=None, borrow=None): implicit=None, borrow=None, shared = False):
#if shared, an input's value comes from its persistent storage, not from a default stored
#in the function or from the caller
self.shared = shared
# mutable implies the output can be both aliased to the input and that the input can be # mutable implies the output can be both aliased to the input and that the input can be
# destroyed. borrow simply implies the output can be aliased to the input. Thus # destroyed. borrow simply implies the output can be aliased to the input. Thus
......
...@@ -427,14 +427,18 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[], ...@@ -427,14 +427,18 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
i.variable = iv i.variable = iv
for sv in shared_inputs: for sv in shared_inputs:
#pass value of None here
#value will be stored in the resulting functions' defaults list
#but since the value of shared variables never needs to be refed, it is not needed
if sv in update_d: if sv in update_d:
si = In(variable=sv, value=sv.container, mutable=True, si = In(variable=sv, value = sv.container, mutable=True,
borrow=True, update=update_d[sv]) borrow=True, update=update_d[sv], shared = True)
else: else:
si = In(variable=sv, value=sv.container, si = In(variable=sv, value = sv.container,
mutable=False, borrow=True) mutable=False, borrow=True, shared = True)
inputs.append(si) inputs.append(si)
return orig_function(inputs, cloned_outputs, mode, return orig_function(inputs, cloned_outputs, mode,
accept_inplace=accept_inplace, name=name, profile=profile) accept_inplace=accept_inplace, name=name, profile=profile)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论