提交 616f8d37 authored 作者: Reyhane Askari's avatar Reyhane Askari

sync attribute added to theano function

上级 e89aee9d
...@@ -76,7 +76,7 @@ def function_dump(filename, inputs, outputs=None, mode=None, updates=None, ...@@ -76,7 +76,7 @@ def function_dump(filename, inputs, outputs=None, mode=None, updates=None,
def function(inputs, outputs=None, mode=None, updates=None, givens=None, def function(inputs, outputs=None, mode=None, updates=None, givens=None,
no_default_updates=False, accept_inplace=False, name=None, no_default_updates=False, accept_inplace=False, name=None,
rebuild_strict=True, allow_input_downcast=None, profile=None, rebuild_strict=True, allow_input_downcast=None, profile=None,
on_unused_input=None): on_unused_input=None, sync=False):
""" """
Return a :class:`callable object <theano.compile.function_module.Function>` Return a :class:`callable object <theano.compile.function_module.Function>`
that will calculate `outputs` from `inputs`. that will calculate `outputs` from `inputs`.
...@@ -323,7 +323,8 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None, ...@@ -323,7 +323,8 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
allow_input_downcast=allow_input_downcast, allow_input_downcast=allow_input_downcast,
on_unused_input=on_unused_input, on_unused_input=on_unused_input,
profile=profile, profile=profile,
output_keys=output_keys) output_keys=output_keys,
sync=sync)
# We need to add the flag check_aliased inputs if we have any mutable or # We need to add the flag check_aliased inputs if we have any mutable or
# borrowed used defined inputs # borrowed used defined inputs
fn._check_for_aliased_inputs = check_for_aliased_inputs fn._check_for_aliased_inputs = check_for_aliased_inputs
......
...@@ -1390,7 +1390,8 @@ class FunctionMaker(object): ...@@ -1390,7 +1390,8 @@ class FunctionMaker(object):
def __init__(self, inputs, outputs, def __init__(self, inputs, outputs,
mode=None, accept_inplace=False, function_builder=Function, mode=None, accept_inplace=False, function_builder=Function,
profile=None, on_unused_input=None, fgraph=None, profile=None, on_unused_input=None, fgraph=None,
output_keys=None): output_keys=None, sync=False):
self.sync = sync
mode = theano.compile.mode.get_mode(mode) mode = theano.compile.mode.get_mode(mode)
# Assert old way of working isn't used # Assert old way of working isn't used
...@@ -1688,6 +1689,12 @@ class FunctionMaker(object): ...@@ -1688,6 +1689,12 @@ class FunctionMaker(object):
fn = self.function_builder(_fn, _i, _o, self.indices, self.outputs, fn = self.function_builder(_fn, _i, _o, self.indices, self.outputs,
defaults, self.unpack_single, defaults, self.unpack_single,
self.return_none, self.output_keys, self) self.return_none, self.output_keys, self)
if self.sync:
for i, inp in enumerate(input_storage):
if i in self.fgraph.update_mapping.values():
inp.data.sync()
fn.profile = self.profile fn.profile = self.profile
return fn return fn
...@@ -1733,7 +1740,7 @@ def register_checker(checker): ...@@ -1733,7 +1740,7 @@ def register_checker(checker):
def orig_function(inputs, outputs, mode=None, accept_inplace=False, def orig_function(inputs, outputs, mode=None, accept_inplace=False,
name=None, profile=None, on_unused_input=None, name=None, profile=None, on_unused_input=None,
output_keys=None): output_keys=None, sync=False):
""" """
Return a Function that will calculate the outputs from the inputs. Return a Function that will calculate the outputs from the inputs.
...@@ -1804,7 +1811,8 @@ def orig_function(inputs, outputs, mode=None, accept_inplace=False, ...@@ -1804,7 +1811,8 @@ def orig_function(inputs, outputs, mode=None, accept_inplace=False,
accept_inplace=accept_inplace, accept_inplace=accept_inplace,
profile=profile, profile=profile,
on_unused_input=on_unused_input, on_unused_input=on_unused_input,
output_keys=output_keys) output_keys=output_keys,
sync=sync)
with theano.configparser.change_flags(compute_test_value="off"): with theano.configparser.change_flags(compute_test_value="off"):
fn = m.create(defaults) fn = m.create(defaults)
finally: finally:
......
...@@ -283,7 +283,7 @@ class Param(In): ...@@ -283,7 +283,7 @@ class Param(In):
def pfunc(params, outputs=None, mode=None, updates=None, givens=None, def pfunc(params, outputs=None, mode=None, updates=None, givens=None,
no_default_updates=False, accept_inplace=False, name=None, no_default_updates=False, accept_inplace=False, name=None,
rebuild_strict=True, allow_input_downcast=None, rebuild_strict=True, allow_input_downcast=None,
profile=None, on_unused_input=None, output_keys=None): profile=None, on_unused_input=None, output_keys=None, sync=False):
""" """
Function-constructor for graphs with shared variables. Function-constructor for graphs with shared variables.
...@@ -483,7 +483,7 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None, ...@@ -483,7 +483,7 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None,
return orig_function(inputs, cloned_outputs, mode, return orig_function(inputs, cloned_outputs, mode,
accept_inplace=accept_inplace, name=name, accept_inplace=accept_inplace, name=name,
profile=profile, on_unused_input=on_unused_input, profile=profile, on_unused_input=on_unused_input,
output_keys=output_keys) output_keys=output_keys, sync=sync)
def _pfunc_param_to_in(param, strict=False, allow_downcast=None): def _pfunc_param_to_in(param, strict=False, allow_downcast=None):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论