提交 ee4eedc6 authored 作者: Iban Harlouchet's avatar Iban Harlouchet

numpydoc for theano/tensor/io.py

上级 c172b4c4
...@@ -11,13 +11,18 @@ import theano ...@@ -11,13 +11,18 @@ import theano
class LoadFromDisk(Op): class LoadFromDisk(Op):
""" """
An operation to load an array from disk An operation to load an array from disk.
See Also See Also
load --------
load
Notes
-----
Non-differentiable.
@note: Non-differentiable.
""" """
__props__ = ("dtype", "broadcastable", "mmap_mode") __props__ = ("dtype", "broadcastable", "mmap_mode")
def __init__(self, dtype, broadcastable, mmap_mode=None): def __init__(self, dtype, broadcastable, mmap_mode=None):
...@@ -53,18 +58,26 @@ def load(path, dtype, broadcastable, mmap_mode=None): ...@@ -53,18 +58,26 @@ def load(path, dtype, broadcastable, mmap_mode=None):
""" """
Load an array from an .npy file. Load an array from an .npy file.
:param path: A Generic symbolic variable, that will contain a string Parameters
:param dtype: The data type of the array to be read. ----------
:param broadcastable: The broadcastable pattern of the loaded array, path
for instance, (False,) for a vector, (False, True) for a column, A Generic symbolic variable, that will contain a string
(False, False) for a matrix. dtype : data-type
:param mmap_mode: How the file will be loaded. None means that the The data type of the array to be read.
data will be copied into an array in memory, 'c' means that the file broadcastable
will be mapped into virtual memory, so only the parts that are The broadcastable pattern of the loaded array, for instance,
needed will be actually read from disk and put into memory. (False,) for a vector, (False, True) for a column,
Other modes supported by numpy.load ('r', 'r+', 'w+') cannot (False, False) for a matrix.
be supported by Theano. mmap_mode
How the file will be loaded. None means that the
data will be copied into an array in memory, 'c' means that the file
will be mapped into virtual memory, so only the parts that are
needed will be actually read from disk and put into memory.
Other modes supported by numpy.load ('r', 'r+', 'w+') cannot
be supported by Theano.
Examples
--------
>>> from theano import * >>> from theano import *
>>> path = Variable(Generic()) >>> path = Variable(Generic())
>>> x = tensor.load(path, 'int64', (False,)) >>> x = tensor.load(path, 'int64', (False,))
...@@ -72,6 +85,7 @@ def load(path, dtype, broadcastable, mmap_mode=None): ...@@ -72,6 +85,7 @@ def load(path, dtype, broadcastable, mmap_mode=None):
>>> fn = function([path], y) >>> fn = function([path], y)
>>> fn("stored-array.npy") >>> fn("stored-array.npy")
array([0, 2, 4, 6, 8], dtype=int64) array([0, 2, 4, 6, 8], dtype=int64)
""" """
return LoadFromDisk(dtype, broadcastable, mmap_mode)(path) return LoadFromDisk(dtype, broadcastable, mmap_mode)(path)
...@@ -91,14 +105,19 @@ else: ...@@ -91,14 +105,19 @@ else:
class MPIRecv(Op): class MPIRecv(Op):
""" """
An operation to asynchronously receive an array to a remote host using MPI An operation to asynchronously receive an array to a remote host using MPI.
See Also See Also
MPIRecv --------
MPIWait MPIRecv
MPIWait
Notes
-----
Non-differentiable.
@note: Non-differentiable.
""" """
__props__ = ("source", "tag", "shape", "dtype") __props__ = ("source", "tag", "shape", "dtype")
def __init__(self, source, tag, shape, dtype): def __init__(self, source, tag, shape, dtype):
...@@ -134,13 +153,18 @@ class MPIRecv(Op): ...@@ -134,13 +153,18 @@ class MPIRecv(Op):
class MPIRecvWait(Op): class MPIRecvWait(Op):
""" """
An operation to wait on a previously received array using MPI An operation to wait on a previously received array using MPI.
See Also See Also
MPIRecv --------
MPIRecv
Notes
-----
Non-differentiable.
@note: Non-differentiable.
""" """
__props__ = ("tag",) __props__ = ("tag",)
def __init__(self, tag): def __init__(self, tag):
...@@ -168,14 +192,19 @@ class MPIRecvWait(Op): ...@@ -168,14 +192,19 @@ class MPIRecvWait(Op):
class MPISend(Op): class MPISend(Op):
""" """
An operation to asynchronously Send an array to a remote host using MPI An operation to asynchronously Send an array to a remote host using MPI.
See Also See Also
MPIRecv --------
MPISendWait MPIRecv
MPISendWait
Notes
-----
Non-differentiable.
@note: Non-differentiable.
""" """
__props__ = ("dest", "tag") __props__ = ("dest", "tag")
def __init__(self, dest, tag): def __init__(self, dest, tag):
...@@ -202,12 +231,16 @@ class MPISend(Op): ...@@ -202,12 +231,16 @@ class MPISend(Op):
class MPISendWait(Op): class MPISendWait(Op):
""" """
An operation to wait on a previously sent array using MPI An operation to wait on a previously sent array using MPI.
See Also
--------
MPISend
See Also: Notes
MPISend -----
Non-differentiable.
@note: Non-differentiable.
""" """
__props__ = ("tag",) __props__ = ("tag",)
...@@ -227,35 +260,35 @@ class MPISendWait(Op): ...@@ -227,35 +260,35 @@ class MPISendWait(Op):
def isend(var, dest, tag): def isend(var, dest, tag):
""" """
Non blocking send Non blocking send.
""" """
return MPISend(dest, tag)(var) return MPISend(dest, tag)(var)
def send(var, dest, tag): def send(var, dest, tag):
""" """
blocking send Blocking send.
""" """
return MPISendWait(tag)(*isend(var, dest, tag)) return MPISendWait(tag)(*isend(var, dest, tag))
def irecv(shape, dtype, source, tag): def irecv(shape, dtype, source, tag):
""" """
non-blocking receive Non-blocking receive.
""" """
return MPIRecv(source, tag, shape, dtype)() return MPIRecv(source, tag, shape, dtype)()
def recv(shape, dtype, source, tag): def recv(shape, dtype, source, tag):
""" """
blocking receive Blocking receive.
""" """
return MPIRecvWait(tag)(*irecv(shape, dtype, source, tag)) return MPIRecvWait(tag)(*irecv(shape, dtype, source, tag))
# Ordering keys for scheduling # Ordering keys for scheduling
def mpi_send_wait_key(a): def mpi_send_wait_key(a):
""" Wait as long as possible on Waits, Start Send/Recvs early """ """Wait as long as possible on Waits, Start Send/Recvs early."""
if isinstance(a.op, (MPIRecvWait, MPISendWait)): if isinstance(a.op, (MPIRecvWait, MPISendWait)):
return 1 return 1
if isinstance(a.op, (MPIRecv, MPISend)): if isinstance(a.op, (MPIRecv, MPISend)):
...@@ -264,7 +297,7 @@ def mpi_send_wait_key(a): ...@@ -264,7 +297,7 @@ def mpi_send_wait_key(a):
def mpi_tag_key(a): def mpi_tag_key(a):
""" Break MPI ties by using the variable tag - prefer lower tags first """ """Break MPI ties by using the variable tag - prefer lower tags first."""
if isinstance(a.op, (MPISend, MPIRecv, MPIRecvWait, MPISendWait)): if isinstance(a.op, (MPISend, MPIRecv, MPIRecvWait, MPISendWait)):
return a.op.tag return a.op.tag
else: else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论