提交 4b29e343 authored 作者: Olivier Breuleux

fixed Tensor

上级 f86767b2
......@@ -417,49 +417,6 @@ class CLinker(Linker):
except AbstractFunctionError: pass
return ret
# def make_function(self, in_order, out_order):
# nin = len(self.inputs)
# nout = len(self.outputs)
# if nin != len(in_order):
# raise TypeError("Wrong number of inputs.")
# if nout != len(out_order):
# raise TypeError("Wrong number of outputs.")
# in_storage = []
# out_storage = []
# cthunk_in_args = [None] * nin
# cthunk_out_args = [None] * nout
# for result in in_order:
# idx = self.inputs.index(result)
# storage = [None]
# cthunk_in_args[idx] = storage
# in_storage.append(storage)
# for result in out_order:
# idx = self.outputs.index(result)
# storage = [None]
# cthunk_out_args[idx] = storage
# out_storage.append(storage)
# for arg in cthunk_in_args + cthunk_out_args:
# if arg is None:
# raise Exception("The inputs or outputs are underspecified.")
# error_storage = [None, None, None]
# cthunk = self.cthunk_factory(error_storage, cthunk_in_args, cthunk_out_args)
# def execute(*args):
# for arg, storage in zip(args, in_storage):
# storage[0] = arg
# failure = cutils.run_cthunk(cthunk)
# if failure:
# raise error_storage[0], error_storage[1] + " " + str(self.find_task(failure - 1))
# return utils.to_return_values([storage[0] for storage in out_storage])
# return execute
def __compile__(self, inplace = False):
if inplace:
in_results = self.inputs
......
......@@ -29,9 +29,8 @@ class Grad(object):
def __getitem__(self, item):
"""Map item to its id and retrieve it."""
key = core.wrap(item)
try:
return self.map[key]
return self.map[item]
except KeyError:
return Undefined
......@@ -60,16 +59,16 @@ class Grad(object):
# nothing to do
return
if r.data is not None and dr.data is not None:
if not hasattr(r, 'shape'):
raise ValueError(('Grad::add r lacks shape: type=',
type(r)))
if not hasattr(dr, 'shape'):
raise ValueError(('Grad::add dr lacks shape: type=',
type(dr)))
if r.shape != dr.shape:
raise ValueError(('Grad::add r, dr shape mismatch',
r.shape, dr.shape))
# if r.data is not None and dr.data is not None:
# if not hasattr(r, 'shape'):
# raise ValueError(('Grad::add r lacks shape: type=',
# type(r)))
# if not hasattr(dr, 'shape'):
# raise ValueError(('Grad::add dr lacks shape: type=',
# type(dr)))
# if r.shape != dr.shape:
# raise ValueError(('Grad::add r, dr shape mismatch',
# r.shape, dr.shape))
# prevent 'r' from being re-calculated by self.__call__ in 'build_eval' mode
if r.state is gof.result.Computed:
......@@ -102,14 +101,12 @@ class Grad(object):
"""
if not maybe_redo and self.did_bprop:
raise Exception('bprop has already been done. Consider calling with maybe_redo=True.')
core.build_mode()
try:
outputs = self.outputs
inputs = gof.graph.inputs(outputs)
for op in gof.graph.io_toposort(inputs, outputs).__reversed__():
op.update_gradient(self)
finally:
core.pop_mode()
self.did_bprop = True
def __call__(self, item):
......@@ -121,8 +118,7 @@ class Grad(object):
if not self.did_bprop:
raise Exception('Grad.__call__ only makes sense after a bprop')
rval = self[item]
if rval is not Undefined \
and core.current_mode() == 'build_eval':
if rval is not Undefined:
compute_from([rval], self._compute_history)
return rval
......
......@@ -6,27 +6,31 @@ from gof import ResultBase
from gof import Op
class NumpyR(ResultBase):
def tensor(data, name = None):
return Tensor(data.dtype, [0]*len(data.shape), data, name)
def __init__(self, dtype, nd, name=None):
self.nd = nd
def _broadcastable_pattern(pattern):
def factory(data = None, name = None):
if data: assert len(data.shape) == len(pattern)
return Tensor(data.dtype, pattern, data, name)
matrix = _broadcastable_pattern([0, 0])
row = _broadcastable_pattern([1, 0])
col = _broadcastable_pattern([0, 1])
class Tensor(ResultBase):
def __init__(self, dtype, broadcastable, data=None, name=None):
self.broadcastable = broadcastable
self.dtype = dtype
ResultBase.__init__(self, role = None, data = None, name = name)
def validate(self, data):
if not isinstance(data, numpy.ndarray):
raise TypeError("Expected ndarray instance.")
elif not len(data.shape) == self.nd:
raise TypeError("Expected ndarray with %i dimensions." % self.nd)
elif not str(data.dtype) == self.dtype:
raise TypeError("Expected ndarray with data type %i." % self.dtype)
# def to_c_type(self, dtype):
# if dtype == "float64":
# return "double"
# else:
# raise TypeError("Cannot translate dtype to C.")
def filter(self, data):
arr = numpy.asarray(data, dtype = self.dtype)
for b, s in zip(self.broadcastable, arr.shape):
assert not b or s == 1
return arr
def c_declare(self):
return """
......@@ -64,35 +68,48 @@ class NumpyR(ResultBase):
return []
def __copy__(self):
cpy = self.__class__(self.dtype, self.nd, self.name)
"""
Returns a copy of this Tensor. If there is data stored inside it, it is also copied.
"""
cpy = self.__class__(self.dtype, self.broadcastable, None, self.name)
cpy.data = copy(self.data)
return cpy
def TheanoOp(Op):
def TensorOp(Op):
nin = -1
nout = 1
def __init__(self, *inputs):
def wrap_as_tensor(x):
if isinstance(x, Tensor):
return x
else:
return Tensor(x)
inputs = map(wrap_as_tensor, inputs)
if self.nin >= 0:
if len(inputs) != self.nin:
raise TypeError("Wrong number of inputs for %s (got %i, expected %i)") \
% (self, len(inputs), self.nin)
i_nds = [getattr(input, 'nd', None) for input in inputs]
i_broadcastables = [getattr(input, 'broadcastable', None) for input in inputs]
i_dtypes = [getattr(input, 'dtype', None) for input in inputs]
o_nds = self.propagate_nd(*i_nds)
o_dtypes = self.propagate_dtypes(*i_dtypes)
return [NumpyR(nd, dtype) for nd, dtype in zip(o_nds, o_dtypes)]
o_broadcastables = utils.from_return_values(self.propagate_broadcastable(*i_broadcastables))
o_dtypes = utils.from_return_values(self.propagate_dtype(*i_dtypes))
self.inputs = inputs
self.outputs = [Tensor(dtype, broadcastable) for broadcastable, dtype in zip(o_broadcastables, o_dtypes)]
def propagate_nds(self, *inputs):
def propagate_broadcastable(self, *inputs):
raise AbstractFunctionError()
def propagate_dtypes(self, *inputs):
def propagate_dtype(self, *inputs):
raise AbstractFunctionError()
def impl(self, *inputs):
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论