Commit 6092703d authored by Olivier Breuleux

named a bunch of ops

Parent c1f0ce3a
...@@ -274,9 +274,9 @@ class Elemwise(Op): ...@@ -274,9 +274,9 @@ class Elemwise(Op):
def __str__(self): def __str__(self):
if self.name is None: if self.name is None:
if self.inplace_pattern: if self.inplace_pattern:
return "Broadcast{%s}%s" % (self.scalar_op, str(self.inplace_pattern)) return "Elemwise{%s}%s" % (self.scalar_op, str(self.inplace_pattern))
else: else:
return "Broadcast{%s}" % (self.scalar_op) return "Elemwise{%s}" % (self.scalar_op)
else: else:
return self.name return self.name
...@@ -628,5 +628,11 @@ class Sum(CAReduce): ...@@ -628,5 +628,11 @@ class Sum(CAReduce):
i += 1 i += 1
return Elemwise(scalar.second)(x, DimShuffle(gz.type.broadcastable, new_dims)(gz)), return Elemwise(scalar.second)(x, DimShuffle(gz.type.broadcastable, new_dims)(gz)),
def __str__(self):
    """Render this op for graph printing: "Sum" or "Sum{axis, ...}"."""
    axes = self.axis
    if axes is None:
        # No axis given: the reduction runs over every dimension.
        return "Sum"
    axis_text = ", ".join(str(a) for a in axes)
    return "Sum{%s}" % axis_text
...@@ -388,6 +388,7 @@ class _tensor_py_operators: ...@@ -388,6 +388,7 @@ class _tensor_py_operators:
raise TypeError('Tensor does not support iteration') raise TypeError('Tensor does not support iteration')
class TensorResult(Result, _tensor_py_operators): class TensorResult(Result, _tensor_py_operators):
pass pass
...@@ -482,7 +483,7 @@ class Shape(Op): ...@@ -482,7 +483,7 @@ class Shape(Op):
""" """
def make_node(self, x): def make_node(self, x):
x = as_tensor(x) x = as_tensor(x)
return Apply(self, [x], [ivector()]) return Apply(self, [x], [lvector()])
def perform(self, node, (x, ), (out, )): def perform(self, node, (x, ), (out, )):
out[0] = numpy.asarray(x.shape) out[0] = numpy.asarray(x.shape)
def grad(self, (x,), (gz,)): def grad(self, (x,), (gz,)):
...@@ -629,59 +630,13 @@ class TransposeInplace(Op): ...@@ -629,59 +630,13 @@ class TransposeInplace(Op):
%(z)s = transposed; %(z)s = transposed;
""" % locals() """ % locals()
def __str__(self):
    # Short name used in graph printouts.
    # NOTE(review): prints "TransposeView" although the class is named
    # TransposeInplace -- presumably because the op returns a view of its
    # input; confirm the mismatch is intentional.
    return "TransposeView"
transpose_inplace = TransposeInplace() transpose_inplace = TransposeInplace()
def transpose(x, **kwargs): def transpose(x, **kwargs):
return transpose_inplace(tensor_copy(x), **kwargs) return transpose_inplace(tensor_copy(x), **kwargs)
# class Subtensor_dx(Op, Viewer):
# """Return a tensor full of zeros, except for what was sliced from x by
# Subtensor.
# @todo: pass the shape of x, rather than x itself.
# @todo: add support for advanced tensor indexing (breaks current perform
# implementation).
# """
# def __init__(self, inputs, idx_list, **kwargs):
# Op.__init__(self, **kwargs)
# self.inputs = inputs
# self.outputs = [Tensor(inputs[0].dtype, inputs[0].broadcastable)]
# self.idx_list = idx_list
# def perform(self):
# x = self.inputs[0]
# gz = self.inputs[-1]
# cdata = []
# for c in self.idx_list:
# if isinstance(c, slice):
# if c.start is None: start = None
# else: start = self.inputs[c.start].data
# if c.stop is None: stop = None
# else: stop = self.inputs[c.stop].data
# if c.step is None: step = None
# else: step = self.inputs[c.step].data
# cdata.append(slice(start, stop, step))
# else:
# d = self.inputs[c].data
# assert 'int' in str(d.dtype)
# cdata.append(d)
# if len(cdata) > 1:
# cdata = tuple(cdata) #there's a diff between tuple and list here...
# else:
# cdata = cdata[0]
# #print cdata
# #print gz.data
# gx = numpy.zeros_like(x.data)
# gx[cdata] = gz.data
# #print gx
# self.outputs[0].data = gx
# def clone_with_new_inputs(self, *new_inputs):
# assert len(self.inputs) == len(new_inputs)
# return Subtensor_dx(new_inputs, self.idx_list)
class Subtensor(Op): class Subtensor(Op):
...@@ -789,7 +744,7 @@ class Subtensor(Op): ...@@ -789,7 +744,7 @@ class Subtensor(Op):
cdata = tuple(map(convert, self.idx_list)) cdata = tuple(map(convert, self.idx_list))
if len(cdata) == 1: if len(cdata) == 1:
cdata = cdata[0] cdata = cdata[0]
out[0] = x.__getitem__(cdata) out[0] = numpy.asarray(x.__getitem__(cdata))
def grad(self, inputs, (gz,)): def grad(self, inputs, (gz,)):
x = inputs[0] x = inputs[0]
...@@ -803,6 +758,16 @@ class Subtensor(Op): ...@@ -803,6 +758,16 @@ class Subtensor(Op):
# FIXME: this doesn't work if there are slices in the list because for some mysterious reason slice is unhashable # FIXME: this doesn't work if there are slices in the list because for some mysterious reason slice is unhashable
return hash(tuple(self.idx_list)) return hash(tuple(self.idx_list))
def __str__(self):
    """Render this op as ClassName{...}, mirroring Python indexing syntax.

    Each plain entry of ``self.idx_list`` is shown as-is; each slice is
    shown as ``start:stop:step`` with ``None`` components left blank.
    """
    def _render(entry):
        if isinstance(entry, slice):
            parts = (entry.start, entry.stop, entry.step)
            # A component of None prints as the empty string, e.g. "1::2".
            return ":".join("" if p is None else str(p) for p in parts)
        return str(entry)

    rendered = [_render(e) for e in self.idx_list]
    return "%s{%s}" % (self.__class__.__name__, ", ".join(rendered))
class SetSubtensor(Subtensor): class SetSubtensor(Subtensor):
view_map = {} view_map = {}
destroy_map = {0: [0]} destroy_map = {0: [0]}
...@@ -942,6 +907,8 @@ class Dot(Op): ...@@ -942,6 +907,8 @@ class Dot(Op):
z[0] = numpy.dot(x, y) z[0] = numpy.dot(x, y)
def grad(self, (x, y), (gz,)): def grad(self, (x, y), (gz,)):
return dot(gz, y.T), dot(x.T, gz) return dot(gz, y.T), dot(x.T, gz)
def __str__(self):
    # Fixed short name for graph printing; Dot carries no parameters
    # worth displaying.
    return "Dot"
dot = Dot() dot = Dot()
class Gemm(Op): class Gemm(Op):
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register or sign in to comment