提交 5ec1191f authored 作者: Frederic Bastien's avatar Frederic Bastien

After talk with James, added an option to function(rebuild_strict=True). When…

After talk with James, added an option to function(rebuild_strict=True). When False allow givens to change the type of inputs to rebuild the graph and make it work with inputs on the gpu and many other case. We did not test all others cases.
上级 858cf0a3
...@@ -11,7 +11,8 @@ from pfunc import pfunc ...@@ -11,7 +11,8 @@ from pfunc import pfunc
from numpy import any #for to work in python 2.4 from numpy import any #for to work in python 2.4
def function(inputs, outputs=None, mode=None, updates=[], givens=[], def function(inputs, outputs=None, mode=None, updates=[], givens=[],
no_default_updates=False, accept_inplace=False, name=None): no_default_updates=False, accept_inplace=False, name=None,
rebuild_strict = True):
""" """
Return a callable object that will calculate `outputs` from `inputs`. Return a callable object that will calculate `outputs` from `inputs`.
...@@ -45,6 +46,15 @@ def function(inputs, outputs=None, mode=None, updates=[], givens=[], ...@@ -45,6 +46,15 @@ def function(inputs, outputs=None, mode=None, updates=[], givens=[],
:returns: a callable object that will compute the outputs (given the inputs) :returns: a callable object that will compute the outputs (given the inputs)
and update the implicit function arguments according to the `updates`. and update the implicit function arguments according to the `updates`.
:param rebuild_strict: Allow givens to change the type of the inputs of the ops. This
allow to change cpu variables with gpu variables. This could
also serve to change vector to matrix to create a minibatch version
of the function in some case(not tested) and to make the function
work with sparse type(not tested) or complex type(not tested).
WARNING: not all ops can be rebuild with inputs of other type!
In that case an error will be probably raised(not tested).
STRONGLY suggested: test the generated graph in DebugMode!
:note: Regarding givens: Be careful to make sure that these substitutions are :note: Regarding givens: Be careful to make sure that these substitutions are
independent--behaviour when Var1 of one pair appears in the graph leading to Var2 in independent--behaviour when Var1 of one pair appears in the graph leading to Var2 in
another expression is undefined. Replacements specified with givens are different from another expression is undefined. Replacements specified with givens are different from
...@@ -72,4 +82,5 @@ def function(inputs, outputs=None, mode=None, updates=[], givens=[], ...@@ -72,4 +82,5 @@ def function(inputs, outputs=None, mode=None, updates=[], givens=[],
updates=updates, updates=updates,
givens=givens, givens=givens,
no_default_updates=no_default_updates, no_default_updates=no_default_updates,
accept_inplace=accept_inplace,name=name) accept_inplace=accept_inplace,name=name,
rebuild_strict=rebuild_strict)
...@@ -34,7 +34,7 @@ class Param(object): ...@@ -34,7 +34,7 @@ class Param(object):
self.implicit = implicit self.implicit = implicit
def pfunc(params, outputs=None, mode=None, updates=[], givens=[], def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
no_default_updates=False, accept_inplace=False, name=None): no_default_updates=False, accept_inplace=False, name=None, rebuild_strict = True):
"""Function-constructor for graphs with shared variables. """Function-constructor for graphs with shared variables.
:type params: list of either Variable or Param instances. :type params: list of either Variable or Param instances.
...@@ -145,7 +145,8 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[], ...@@ -145,7 +145,8 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
if a not in clone_d: if a not in clone_d:
for i in a.inputs: for i in a.inputs:
clone_v_get_shared_updates(i) clone_v_get_shared_updates(i)
clone_d[a] = a.clone_with_new_inputs([clone_d[i] for i in a.inputs]) clone_d[a] = a.clone_with_new_inputs([clone_d[i] for i in a.inputs],
strict = rebuild_strict)
for old_o, new_o in zip(a.outputs, clone_d[a].outputs): for old_o, new_o in zip(a.outputs, clone_d[a].outputs):
clone_d.setdefault(old_o, new_o) clone_d.setdefault(old_o, new_o)
return clone_d[a] return clone_d[a]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论