提交 8022ce3e authored 作者: Olivier Breuleux's avatar Olivier Breuleux

some doc for RandomFunction

上级 dbc0ea78
...@@ -10,6 +10,12 @@ from copy import copy ...@@ -10,6 +10,12 @@ from copy import copy
class RandomFunction(gof.Op): class RandomFunction(gof.Op):
def __init__(self, fn, outtype, *args, **kwargs): def __init__(self, fn, outtype, *args, **kwargs):
"""
fn: a random function with the same signature as functions in numpy.random.RandomState
outtype: the type of the output
args: a list of default arguments for the function
kwargs: if the 'inplace' key is there, its value will be used to determine if the op operates inplace or not
"""
self.fn = fn self.fn = fn
self.outtype = outtype self.outtype = outtype
self.args = map(tensor.as_tensor, args) self.args = map(tensor.as_tensor, args)
...@@ -18,6 +24,13 @@ class RandomFunction(gof.Op): ...@@ -18,6 +24,13 @@ class RandomFunction(gof.Op):
self.destroy_map = {0: [0]} self.destroy_map = {0: [0]}
def make_node(self, r, shape, *args): def make_node(self, r, shape, *args):
"""
in: r -> RandomState (gof.generic),
shape -> lvector
args -> the arguments expected by the numpy function
out: r2 -> the new RandomState (gof.generic)
out -> the random numbers we generated
"""
args = map(tensor.as_tensor, args) args = map(tensor.as_tensor, args)
shape = tensor.as_tensor(shape) shape = tensor.as_tensor(shape)
assert shape.type == tensor.lvector assert shape.type == tensor.lvector
...@@ -52,6 +65,17 @@ class RandomFunction(gof.Op): ...@@ -52,6 +65,17 @@ class RandomFunction(gof.Op):
def random_function(fn, dtype, *rfargs, **rfkwargs): def random_function(fn, dtype, *rfargs, **rfkwargs):
"""
Returns a wrapper around RandomFunction which automatically infers the number
of dimensions of the output from the given shape. If the shape cannot be inferred,
the user can give an integer as first argument, which will be interpreted as the
number of dimensions.
The number of dimensions for the following shape arguments can be inferred:
- shape(x)
- make_lvector(x, y, z, ...)
- constants
"""
def f(ndim, *args, **kwargs): def f(ndim, *args, **kwargs):
if isinstance(ndim, int): if isinstance(ndim, int):
r, shape, args = args[0], args[1], args[2:] r, shape, args = args[0], args[1], args[2:]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论