提交 fef231b0 authored 作者: Matthew Rocklin's avatar Matthew Rocklin

path is now treated as a theano variable

上级 211bc60d
...@@ -2029,30 +2029,34 @@ class LoadFromDisk(Op): ...@@ -2029,30 +2029,34 @@ class LoadFromDisk(Op):
""" """
@note: Non-differentiable. @note: Non-differentiable.
""" """
def __init__(self, path, dtype): def __init__(self, dtype, broadcastable):
self.path = path
self.dtype = dtype self.dtype = dtype
self.broadcastable = broadcastable
def __eq__(self, other): def __eq__(self, other):
return (type(self) == type(other) and return (type(self) == type(other) and
self.path == other.path and self.broadcastable == other.broadcastable and
self.dtype == other.dtype) self.dtype == other.dtype)
def __hash__(self): def __hash__(self):
return hash((type(self), self.path, self.dtype)) return hash((type(self), self.dtype, self.broadcastable))
def make_node(self): def make_node(self, path):
return gof.Apply(self, [], [tensor(self.dtype, broadcastable=(False,))]) if isinstance(path, str):
path = Constant(Generic(), path)
return gof.Apply(self, [path], [tensor(self.dtype,
broadcastable=self.broadcastable)])
def perform(self, node, inp, out): def perform(self, node, inp, out):
d = numpy.load(self.path) path = inp[0]
d = numpy.load(path)
out[0][0] = d[d.keys()[0]].astype(self.dtype) out[0][0] = d[d.keys()[0]].astype(self.dtype)
def __str__(self): def __str__(self):
return "Load: %s"%self.path return "Load: %s, %s"%(self.dtype, self.broadcastable)
def load(path, dtype='float64'): def load(path, dtype, broadcastable):
return LoadFromDisk(path, dtype)() return LoadFromDisk(dtype, broadcastable)(path)
########################## ##########################
# Unary Operations # Unary Operations
......
...@@ -3969,10 +3969,11 @@ class T_load_tensor(unittest.TestCase): ...@@ -3969,10 +3969,11 @@ class T_load_tensor(unittest.TestCase):
data = numpy.arange(5) data = numpy.arange(5)
filename = "_load_tensor_test_1.npz" filename = "_load_tensor_test_1.npz"
numpy.savez(filename, data) numpy.savez(filename, data)
x = tensor.load(filename, 'int64') path = Variable(Generic())
x = tensor.load(path, 'int64', (False,))
y = x*2 y = x*2
fn = function([], [y]) fn = function([path], [y])
assert (fn() == data*2).all() assert (fn(filename) == data*2).all()
class test_grad(unittest.TestCase): class test_grad(unittest.TestCase):
class O(gof.op.Op): class O(gof.op.Op):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论