提交 65465106 authored 作者: Reyhane Askari's avatar Reyhane Askari

changed sync to function

上级 50f66cc6
......@@ -76,7 +76,7 @@ def function_dump(filename, inputs, outputs=None, mode=None, updates=None,
def function(inputs, outputs=None, mode=None, updates=None, givens=None,
no_default_updates=False, accept_inplace=False, name=None,
rebuild_strict=True, allow_input_downcast=None, profile=None,
on_unused_input=None, sync=False):
on_unused_input=None):
"""
Return a :class:`callable object <theano.compile.function_module.Function>`
that will calculate `outputs` from `inputs`.
......@@ -323,8 +323,7 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
allow_input_downcast=allow_input_downcast,
on_unused_input=on_unused_input,
profile=profile,
output_keys=output_keys,
sync=sync)
output_keys=output_keys)
# We need to add the flag check_aliased inputs if we have any mutable or
# borrowed used defined inputs
fn._check_for_aliased_inputs = check_for_aliased_inputs
......
......@@ -1007,6 +1007,13 @@ class Function(object):
"""
return [i.variable for i in self.maker.inputs if i.implicit]
def sync_shared(self):
for i, inp in enumerate(self.input_storage):
if i in self.maker.fgraph.update_mapping.values():
if (hasattr(theano, "gpuarray") and
isinstance(inp.data, pygpu.gpuarray.GpuArray)):
inp.data.sync()
# pickling/deepcopy support for Function
def _pickle_Function(f):
......@@ -1391,8 +1398,7 @@ class FunctionMaker(object):
def __init__(self, inputs, outputs,
mode=None, accept_inplace=False, function_builder=Function,
profile=None, on_unused_input=None, fgraph=None,
output_keys=None, sync=False):
self.sync = sync
output_keys=None):
mode = theano.compile.mode.get_mode(mode)
# Assert old way of working isn't used
......@@ -1691,13 +1697,6 @@ class FunctionMaker(object):
defaults, self.unpack_single,
self.return_none, self.output_keys, self)
if self.sync:
for i, inp in enumerate(input_storage):
if i in self.fgraph.update_mapping.values():
if (hasattr(theano, "gpuarray") and
isinstance(inp.data, pygpu.gpuarray.GpuArray)):
inp.data.sync()
fn.profile = self.profile
return fn
......@@ -1743,7 +1742,7 @@ def register_checker(checker):
def orig_function(inputs, outputs, mode=None, accept_inplace=False,
name=None, profile=None, on_unused_input=None,
output_keys=None, sync=False):
output_keys=None):
"""
Return a Function that will calculate the outputs from the inputs.
......@@ -1814,8 +1813,7 @@ def orig_function(inputs, outputs, mode=None, accept_inplace=False,
accept_inplace=accept_inplace,
profile=profile,
on_unused_input=on_unused_input,
output_keys=output_keys,
sync=sync)
output_keys=output_keys)
with theano.configparser.change_flags(compute_test_value="off"):
fn = m.create(defaults)
finally:
......
......@@ -283,7 +283,7 @@ class Param(In):
def pfunc(params, outputs=None, mode=None, updates=None, givens=None,
no_default_updates=False, accept_inplace=False, name=None,
rebuild_strict=True, allow_input_downcast=None,
profile=None, on_unused_input=None, output_keys=None, sync=False):
profile=None, on_unused_input=None, output_keys=None):
"""
Function-constructor for graphs with shared variables.
......@@ -483,7 +483,7 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None,
return orig_function(inputs, cloned_outputs, mode,
accept_inplace=accept_inplace, name=name,
profile=profile, on_unused_input=on_unused_input,
output_keys=output_keys, sync=sync)
output_keys=output_keys)
def _pfunc_param_to_in(param, strict=False, allow_downcast=None):
......
......@@ -919,8 +919,9 @@ def test_sync():
updates = [(w, w + T.sum(T.dot(x, w) +
T.dot(5 * x, 2 * w)))]
f = theano.function([x], y, updates=updates, sync=True)
g = theano.function([x], y, updates=updates, sync=False)
f = theano.function([x], y, updates=updates)
f.sync_shared()
g = theano.function([x], y, updates=updates)
x_ = np.random.rand(100, 300).astype('float32')
f(x_)
g(x_)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论