提交 da3060fe authored 作者: Frederic's avatar Frederic

pep8

上级 047cb7d1
...@@ -14,19 +14,24 @@ def get_diagonal_subtensor_view(x, i0, i1): ...@@ -14,19 +14,24 @@ def get_diagonal_subtensor_view(x, i0, i1):
xview.strides = strides xview.strides = strides
return xview return xview
class DiagonalSubtensor(Op): class DiagonalSubtensor(Op):
def __init__(self, inplace): def __init__(self, inplace):
self.inplace = inplace self.inplace = inplace
if inplace: if inplace:
self.view_map = {0:[0]} self.view_map = {0: [0]}
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) and self.inplace == other.inplace return type(self) == type(other) and self.inplace == other.inplace
def __hash__(self): def __hash__(self):
return hash((type(self), self.inplace)) return hash((type(self), self.inplace))
def make_node(self, x, i0, i1): def make_node(self, x, i0, i1):
_i0 = tensor.as_tensor_variable(i0) _i0 = tensor.as_tensor_variable(i0)
_i1 = tensor.as_tensor_variable(i1) _i1 = tensor.as_tensor_variable(i1)
return Apply(self, [x, _i0, _i1], [x.type()]) return Apply(self, [x, _i0, _i1], [x.type()])
def perform(self, node, inputs, output_storage): def perform(self, node, inputs, output_storage):
xview = get_diagonal_subtensor_view(*inputs) xview = get_diagonal_subtensor_view(*inputs)
if self.inplace: if self.inplace:
...@@ -45,19 +50,24 @@ class DiagonalSubtensor(Op): ...@@ -45,19 +50,24 @@ class DiagonalSubtensor(Op):
diagonal_subtensor = DiagonalSubtensor(False) diagonal_subtensor = DiagonalSubtensor(False)
class IncDiagonalSubtensor(Op): class IncDiagonalSubtensor(Op):
def __init__(self, inplace): def __init__(self, inplace):
self.inplace = inplace self.inplace = inplace
if inplace: if inplace:
self.destroy_map = {0:[0]} self.destroy_map = {0: [0]}
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) and self.inplace == other.inplace return type(self) == type(other) and self.inplace == other.inplace
def __hash__(self): def __hash__(self):
return hash((type(self), self.inplace)) return hash((type(self), self.inplace))
def make_node(self, x, i0, i1, amt): def make_node(self, x, i0, i1, amt):
_i0 = tensor.as_tensor_variable(i0) _i0 = tensor.as_tensor_variable(i0)
_i1 = tensor.as_tensor_variable(i1) _i1 = tensor.as_tensor_variable(i1)
return Apply(self, [x, _i0, _i1, amt], [x.type()]) return Apply(self, [x, _i0, _i1, amt], [x.type()])
def perform(self, node, inputs, output_storage): def perform(self, node, inputs, output_storage):
x, i0, i1, amt = inputs x, i0, i1, amt = inputs
if not self.inplace: if not self.inplace:
...@@ -77,9 +87,10 @@ class IncDiagonalSubtensor(Op): ...@@ -77,9 +87,10 @@ class IncDiagonalSubtensor(Op):
return rval return rval
inc_diagonal_subtensor = IncDiagonalSubtensor(False) inc_diagonal_subtensor = IncDiagonalSubtensor(False)
def conv3d(signals, filters, def conv3d(signals, filters,
signals_shape=None, filters_shape=None, signals_shape=None, filters_shape=None,
border_mode='valid', subsample=(1,1,1), **kwargs): border_mode='valid', subsample=(1, 1, 1), **kwargs):
""" """
Convolve spatio-temporal filters with a movie. Convolve spatio-temporal filters with a movie.
...@@ -101,8 +112,6 @@ def conv3d(signals, filters, ...@@ -101,8 +112,6 @@ def conv3d(signals, filters,
_signals_shape_5d = signals.shape if signals_shape is None else signals_shape _signals_shape_5d = signals.shape if signals_shape is None else signals_shape
_filters_shape_5d = filters.shape if filters_shape is None else filters_shape _filters_shape_5d = filters.shape if filters_shape is None else filters_shape
_signals_shape_4d = ( _signals_shape_4d = (
_signals_shape_5d[0] * _signals_shape_5d[1], _signals_shape_5d[0] * _signals_shape_5d[1],
_signals_shape_5d[2], _signals_shape_5d[2],
...@@ -124,7 +133,7 @@ def conv3d(signals, filters, ...@@ -124,7 +133,7 @@ def conv3d(signals, filters,
filters.reshape(_filters_shape_4d), filters.reshape(_filters_shape_4d),
image_shape=_signals_shape_4d, image_shape=_signals_shape_4d,
filter_shape=_filters_shape_4d, filter_shape=_filters_shape_4d,
border_mode = border_mode[1]) #ignoring border_mode[2] border_mode = border_mode[1]) # ignoring border_mode[2]
# reshape the output to restore its original size # reshape the output to restore its original size
# shape = Ns, Ts, Nf, Tf, W-Wf+1, H-Hf+1 # shape = Ns, Ts, Nf, Tf, W-Wf+1, H-Hf+1
...@@ -139,10 +148,10 @@ def conv3d(signals, filters, ...@@ -139,10 +148,10 @@ def conv3d(signals, filters,
)) ))
elif border_mode[1] == 'full': elif border_mode[1] == 'full':
out_tmp = out_4d.reshape(( out_tmp = out_4d.reshape((
_signals_shape_5d[0], #Ns _signals_shape_5d[0], # Ns
_signals_shape_5d[1], #Ts _signals_shape_5d[1], # Ts
_filters_shape_5d[0], #Nf _filters_shape_5d[0], # Nf
_filters_shape_5d[1], #Tf _filters_shape_5d[1], # Tf
_signals_shape_5d[3] + _filters_shape_5d[3] - 1, _signals_shape_5d[3] + _filters_shape_5d[3] - 1,
_signals_shape_5d[4] + _filters_shape_5d[4] - 1, _signals_shape_5d[4] + _filters_shape_5d[4] - 1,
)) ))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论