提交 d2975234 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Merge perform_noview and perform_view since they do almost the same thing.

上级 5e80c182
...@@ -394,9 +394,6 @@ class ExtractDiag(Op): ...@@ -394,9 +394,6 @@ class ExtractDiag(Op):
self.view = view self.view = view
if self.view: if self.view:
self.view_map = {0:[0]} self.view_map = {0:[0]}
self.perform = self.perform_view
else:
self.perform = self.perform_noview
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) and self.view == other.view return type(self) == type(other) and self.view == other.view
def __hash__(self): def __hash__(self):
...@@ -406,20 +403,16 @@ class ExtractDiag(Op): ...@@ -406,20 +403,16 @@ class ExtractDiag(Op):
if x.type.ndim != 2: if x.type.ndim != 2:
raise TypeError('ExtractDiag only works on matrices', _x) raise TypeError('ExtractDiag only works on matrices', _x)
return Apply(self, [x], [tensor.vector(dtype=x.type.dtype)]) return Apply(self, [x], [tensor.vector(dtype=x.type.dtype)])
def perform_noview(self, node, (x,), (z,)): def perform(self, node, (x,), (z,)):
#for some reason numpy.diag(x) is really slow #for some reason numpy.diag(x) is really slow
N,M = x.shape N,M = x.shape
assert N==M assert N==M
rval = x[0] rval = x[0]
rval.strides = (x.strides[0]+x.strides[1],) rval.strides = (x.strides[0]+x.strides[1],)
z[0] = rval.copy() if self.view:
def perform_view(self, node, (x,), (z,)):
N,M = x.shape
a,b = x.strides
assert N==M
rval = x[0]
rval.strides = a+b,
z[0] = rval z[0] = rval
else:
z[0] = rval.copy()
def __str__(self): def __str__(self):
return 'ExtractDiag{view=%s}'%self.view return 'ExtractDiag{view=%s}'%self.view
def grad(self, inputs, g_outputs): def grad(self, inputs, g_outputs):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论