提交 65465106 authored 作者: Reyhane Askari's avatar Reyhane Askari

Changed `sync` from a compile-time keyword argument on `function`/`orig_function`/`pfunc` into an explicit `Function.sync_shared()` method callable after compilation.

上级 50f66cc6
...@@ -76,7 +76,7 @@ def function_dump(filename, inputs, outputs=None, mode=None, updates=None, ...@@ -76,7 +76,7 @@ def function_dump(filename, inputs, outputs=None, mode=None, updates=None,
def function(inputs, outputs=None, mode=None, updates=None, givens=None, def function(inputs, outputs=None, mode=None, updates=None, givens=None,
no_default_updates=False, accept_inplace=False, name=None, no_default_updates=False, accept_inplace=False, name=None,
rebuild_strict=True, allow_input_downcast=None, profile=None, rebuild_strict=True, allow_input_downcast=None, profile=None,
on_unused_input=None, sync=False): on_unused_input=None):
""" """
Return a :class:`callable object <theano.compile.function_module.Function>` Return a :class:`callable object <theano.compile.function_module.Function>`
that will calculate `outputs` from `inputs`. that will calculate `outputs` from `inputs`.
...@@ -323,8 +323,7 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None, ...@@ -323,8 +323,7 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
allow_input_downcast=allow_input_downcast, allow_input_downcast=allow_input_downcast,
on_unused_input=on_unused_input, on_unused_input=on_unused_input,
profile=profile, profile=profile,
output_keys=output_keys, output_keys=output_keys)
sync=sync)
# We need to add the flag check_aliased inputs if we have any mutable or # We need to add the flag check_aliased inputs if we have any mutable or
# borrowed used defined inputs # borrowed used defined inputs
fn._check_for_aliased_inputs = check_for_aliased_inputs fn._check_for_aliased_inputs = check_for_aliased_inputs
......
...@@ -1007,6 +1007,13 @@ class Function(object): ...@@ -1007,6 +1007,13 @@ class Function(object):
""" """
return [i.variable for i in self.maker.inputs if i.implicit] return [i.variable for i in self.maker.inputs if i.implicit]
def sync_shared(self):
    """Block until pending asynchronous GPU transfers on updated
    shared variables have completed.

    Walks the function's input storage and, for every container whose
    index appears as a target in the fgraph's update mapping, calls
    ``sync()`` on its data when that data is a ``pygpu`` ``GpuArray``.
    Containers holding non-GPU data are left untouched, and the method
    is a no-op when the gpuarray backend is absent.
    """
    updated_indices = self.maker.fgraph.update_mapping.values()
    for index, container in enumerate(self.input_storage):
        if index not in updated_indices:
            continue
        # Only GPU-backed storage has (and needs) an explicit sync;
        # guard on the backend being present before touching pygpu.
        is_gpu_array = (hasattr(theano, "gpuarray") and
                        isinstance(container.data, pygpu.gpuarray.GpuArray))
        if is_gpu_array:
            container.data.sync()
# pickling/deepcopy support for Function # pickling/deepcopy support for Function
def _pickle_Function(f): def _pickle_Function(f):
...@@ -1391,8 +1398,7 @@ class FunctionMaker(object): ...@@ -1391,8 +1398,7 @@ class FunctionMaker(object):
def __init__(self, inputs, outputs, def __init__(self, inputs, outputs,
mode=None, accept_inplace=False, function_builder=Function, mode=None, accept_inplace=False, function_builder=Function,
profile=None, on_unused_input=None, fgraph=None, profile=None, on_unused_input=None, fgraph=None,
output_keys=None, sync=False): output_keys=None):
self.sync = sync
mode = theano.compile.mode.get_mode(mode) mode = theano.compile.mode.get_mode(mode)
# Assert old way of working isn't used # Assert old way of working isn't used
...@@ -1691,13 +1697,6 @@ class FunctionMaker(object): ...@@ -1691,13 +1697,6 @@ class FunctionMaker(object):
defaults, self.unpack_single, defaults, self.unpack_single,
self.return_none, self.output_keys, self) self.return_none, self.output_keys, self)
if self.sync:
for i, inp in enumerate(input_storage):
if i in self.fgraph.update_mapping.values():
if (hasattr(theano, "gpuarray") and
isinstance(inp.data, pygpu.gpuarray.GpuArray)):
inp.data.sync()
fn.profile = self.profile fn.profile = self.profile
return fn return fn
...@@ -1743,7 +1742,7 @@ def register_checker(checker): ...@@ -1743,7 +1742,7 @@ def register_checker(checker):
def orig_function(inputs, outputs, mode=None, accept_inplace=False, def orig_function(inputs, outputs, mode=None, accept_inplace=False,
name=None, profile=None, on_unused_input=None, name=None, profile=None, on_unused_input=None,
output_keys=None, sync=False): output_keys=None):
""" """
Return a Function that will calculate the outputs from the inputs. Return a Function that will calculate the outputs from the inputs.
...@@ -1814,8 +1813,7 @@ def orig_function(inputs, outputs, mode=None, accept_inplace=False, ...@@ -1814,8 +1813,7 @@ def orig_function(inputs, outputs, mode=None, accept_inplace=False,
accept_inplace=accept_inplace, accept_inplace=accept_inplace,
profile=profile, profile=profile,
on_unused_input=on_unused_input, on_unused_input=on_unused_input,
output_keys=output_keys, output_keys=output_keys)
sync=sync)
with theano.configparser.change_flags(compute_test_value="off"): with theano.configparser.change_flags(compute_test_value="off"):
fn = m.create(defaults) fn = m.create(defaults)
finally: finally:
......
...@@ -283,7 +283,7 @@ class Param(In): ...@@ -283,7 +283,7 @@ class Param(In):
def pfunc(params, outputs=None, mode=None, updates=None, givens=None, def pfunc(params, outputs=None, mode=None, updates=None, givens=None,
no_default_updates=False, accept_inplace=False, name=None, no_default_updates=False, accept_inplace=False, name=None,
rebuild_strict=True, allow_input_downcast=None, rebuild_strict=True, allow_input_downcast=None,
profile=None, on_unused_input=None, output_keys=None, sync=False): profile=None, on_unused_input=None, output_keys=None):
""" """
Function-constructor for graphs with shared variables. Function-constructor for graphs with shared variables.
...@@ -483,7 +483,7 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None, ...@@ -483,7 +483,7 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None,
return orig_function(inputs, cloned_outputs, mode, return orig_function(inputs, cloned_outputs, mode,
accept_inplace=accept_inplace, name=name, accept_inplace=accept_inplace, name=name,
profile=profile, on_unused_input=on_unused_input, profile=profile, on_unused_input=on_unused_input,
output_keys=output_keys, sync=sync) output_keys=output_keys)
def _pfunc_param_to_in(param, strict=False, allow_downcast=None): def _pfunc_param_to_in(param, strict=False, allow_downcast=None):
......
...@@ -919,8 +919,9 @@ def test_sync(): ...@@ -919,8 +919,9 @@ def test_sync():
updates = [(w, w + T.sum(T.dot(x, w) + updates = [(w, w + T.sum(T.dot(x, w) +
T.dot(5 * x, 2 * w)))] T.dot(5 * x, 2 * w)))]
f = theano.function([x], y, updates=updates, sync=True) f = theano.function([x], y, updates=updates)
g = theano.function([x], y, updates=updates, sync=False) f.sync_shared()
g = theano.function([x], y, updates=updates)
x_ = np.random.rand(100, 300).astype('float32') x_ = np.random.rand(100, 300).astype('float32')
f(x_) f(x_)
g(x_) g(x_)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论