Unverified 提交 8cdc7074 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: GitHub

Remove tensor/io.py (#1766)

上级 cc674a11
...@@ -131,7 +131,6 @@ from pytensor.tensor.basic import * ...@@ -131,7 +131,6 @@ from pytensor.tensor.basic import *
from pytensor.tensor.blas import batched_dot, batched_tensordot from pytensor.tensor.blas import batched_dot, batched_tensordot
from pytensor.tensor.extra_ops import * from pytensor.tensor.extra_ops import *
from pytensor.tensor.interpolate import interp, interpolate1d from pytensor.tensor.interpolate import interp, interpolate1d
from pytensor.tensor.io import *
from pytensor.tensor.math import * from pytensor.tensor.math import *
from pytensor.tensor.pad import pad from pytensor.tensor.pad import pad
from pytensor.tensor.shape import ( from pytensor.tensor.shape import (
......
from pathlib import Path
import numpy as np
from pytensor.graph.basic import Apply, Constant
from pytensor.graph.op import Op
from pytensor.link.c.type import Generic
from pytensor.tensor.type import tensor
class LoadFromDisk(Op):
"""
An operation to load an array from disk.
See Also
--------
load
Notes
-----
Non-differentiable.
"""
__props__ = ("dtype", "shape", "mmap_mode")
def __init__(self, dtype, shape, mmap_mode=None):
self.dtype = np.dtype(dtype).name
self.shape = shape
if mmap_mode not in (None, "c"):
raise ValueError(
"The only supported values for mmap_mode "
f"are None and 'c', got {mmap_mode}"
)
self.mmap_mode = mmap_mode
def make_node(self, path):
if isinstance(path, str):
path = Constant(Generic(), path)
return Apply(self, [path], [tensor(dtype=self.dtype, shape=self.shape)])
def perform(self, node, inp, out):
path = Path(inp[0])
if path.suffix != ".npy":
raise ValueError(f"Expected a .npy file, got {path} instead")
result = np.load(path, mmap_mode=self.mmap_mode)
if result.dtype != self.dtype:
raise TypeError(
f"Expected an array of type {self.dtype}, got {result.dtype} instead"
)
out[0][0] = result
def __str__(self):
return (
f"Load{{dtype: {self.dtype}, shape: {self.shape}, mmep: {self.mmap_mode}}}"
)
def load(path, dtype, shape, mmap_mode=None):
"""
Load an array from an .npy file.
Parameters
----------
path
A Generic symbolic variable, that will contain a string
dtype : data-type
The data type of the array to be read.
shape
The static shape information of the loaded array.
mmap_mode
How the file will be loaded. None means that the
data will be copied into an array in memory, 'c' means that the file
will be mapped into virtual memory, so only the parts that are
needed will be actually read from disk and put into memory.
Other modes supported by numpy.load ('r', 'r+', 'w+') cannot
be supported by PyTensor.
Examples
--------
>>> from pytensor import *
>>> path = Variable(Generic(), None)
>>> x = tensor.load(path, "int64", (None,))
>>> y = x * 2
>>> fn = function([path], y)
>>> fn("stored-array.npy") # doctest: +SKIP
array([0, 2, 4, 6, 8], dtype=int64)
"""
return LoadFromDisk(dtype, shape, mmap_mode)(path)
__all__ = ["load"]
import numpy as np
import pytest
import pytensor
from pytensor import function
from pytensor.graph.basic import Variable
from pytensor.link.c.type import Generic
from pytensor.tensor.io import load
class TestLoadTensor:
def setup_method(self):
self.data = np.arange(5, dtype=np.int32)
self.filename = pytensor.config.compiledir / "_test.npy"
np.save(self.filename, self.data)
def test_basic(self):
path = Variable(Generic(), None)
# Not specifying mmap_mode defaults to None, and the data is
# copied into main memory
x = load(path, "int32", (None,))
y = x * 2
fn = function([path], y)
assert (fn(self.filename) == (self.data * 2)).all()
def test_invalid_modes(self):
# Modes 'r+', 'r', and 'w+' cannot work with PyTensor, becausei
# the output array may be modified inplace, and that should not
# modify the original file.
path = Variable(Generic(), None)
for mmap_mode in ("r+", "r", "w+", "toto"):
with pytest.raises(ValueError):
load(path, "int32", (None,), mmap_mode)
def test_copy_on_write(self):
path = Variable(Generic(), None)
# 'c' means "copy-on-write", which allow the array to be overwritten
# by an inplace Op in the graph, without modifying the underlying
# file.
x = load(path, "int32", (None,), "c")
# x ** 2 has been chosen because it will work inplace.
y = (x**2).sum()
fn = function([path], y)
# Call fn() twice, to check that inplace ops do not cause trouble
assert (fn(self.filename) == (self.data**2).sum()).all()
assert (fn(self.filename) == (self.data**2).sum()).all()
def test_memmap(self):
path = Variable(Generic(), None)
x = load(path, "int32", (None,), mmap_mode="c")
fn = function([path], x)
assert isinstance(fn(self.filename), np.memmap)
def teardown_method(self):
(pytensor.config.compiledir / "_test.npy").unlink()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论