提交 85bb8b32 authored 作者: James Bergstra's avatar James Bergstra

Merge pull request #1 from nouiz/master

[MAIN] Update grad() method so it does not return None.
from theano.gradient import DisconnectedType
from theano.gof import Op, Apply from theano.gof import Op, Apply
from theano import tensor from theano import tensor
def get_diagonal_subtensor_view(x, i0, i1): def get_diagonal_subtensor_view(x, i0, i1):
if x.shape[i0] < x.shape[i1]: if x.shape[i0] < x.shape[i1]:
raise NotImplementedError('is this allowed?') raise NotImplementedError('is this allowed?')
...@@ -12,44 +14,60 @@ def get_diagonal_subtensor_view(x, i0, i1): ...@@ -12,44 +14,60 @@ def get_diagonal_subtensor_view(x, i0, i1):
xview.strides = strides xview.strides = strides
return xview return xview
class DiagonalSubtensor(Op):
    """Return a diagonal-band view of a tensor, as produced by
    get_diagonal_subtensor_view.

    The op takes a tensor ``x`` and two axis indices ``i0``/``i1`` and
    yields the corresponding diagonal subtensor.  With ``inplace=True``
    the output is a view of the input.
    """

    def __init__(self, inplace):
        # When inplace, output 0 is declared as a view of input 0.
        self.inplace = inplace
        if self.inplace:
            self.view_map = {0: [0]}

    def __eq__(self, other):
        # Equal iff same op class and same inplace flag.
        if type(self) != type(other):
            return False
        return self.inplace == other.inplace

    def __hash__(self):
        # Must be consistent with __eq__: keyed on class and inplace flag.
        return hash((type(self), self.inplace))

    def make_node(self, x, i0, i1):
        # i0 and i1 are the two axes along which the diagonal is taken;
        # they are converted to symbolic tensors here.
        axis0 = tensor.as_tensor_variable(i0)
        axis1 = tensor.as_tensor_variable(i1)
        return Apply(self, [x, axis0, axis1], [x.type()])

    def perform(self, node, inputs, output_storage):
        view = get_diagonal_subtensor_view(*inputs)
        # Hand back the raw view when inplace, otherwise a private copy.
        output_storage[0][0] = view if self.inplace else view.copy()

    def grad(self, inputs, g_outputs):
        x, i0, i1 = inputs
        gz = g_outputs[0]
        # Scatter the output gradient back into a zero tensor shaped like x.
        gx = inc_diagonal_subtensor(tensor.zeros_like(x), i0, i1, gz)
        # The integer axis arguments carry no gradient.
        return [gx, DisconnectedType()(), DisconnectedType()()]

    def connection_pattern(self, node):
        # Only the tensor input influences the output; i0/i1 do not.
        return [[True], [False], [False]]
diagonal_subtensor = DiagonalSubtensor(False) diagonal_subtensor = DiagonalSubtensor(False)
class IncDiagonalSubtensor(Op): class IncDiagonalSubtensor(Op):
def __init__(self, inplace): def __init__(self, inplace):
self.inplace = inplace self.inplace = inplace
if inplace: if inplace:
self.destroy_map = {0:[0]} self.destroy_map = {0: [0]}
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) and self.inplace == other.inplace return type(self) == type(other) and self.inplace == other.inplace
def __hash__(self): def __hash__(self):
return hash((type(self), self.inplace)) return hash((type(self), self.inplace))
def make_node(self, x, i0, i1, amt): def make_node(self, x, i0, i1, amt):
_i0 = tensor.as_tensor_variable(i0) _i0 = tensor.as_tensor_variable(i0)
_i1 = tensor.as_tensor_variable(i1) _i1 = tensor.as_tensor_variable(i1)
return Apply(self, [x, _i0, _i1, amt], [x.type()]) return Apply(self, [x, _i0, _i1, amt], [x.type()])
def perform(self, node, inputs, output_storage): def perform(self, node, inputs, output_storage):
x, i0, i1, amt = inputs x, i0, i1, amt = inputs
if not self.inplace: if not self.inplace:
...@@ -57,15 +75,22 @@ class IncDiagonalSubtensor(Op): ...@@ -57,15 +75,22 @@ class IncDiagonalSubtensor(Op):
xview = get_diagonal_subtensor_view(x, i0, i1) xview = get_diagonal_subtensor_view(x, i0, i1)
xview += amt xview += amt
output_storage[0][0] = x output_storage[0][0] = x
def grad(self, inputs, g_outputs): def grad(self, inputs, g_outputs):
x, i0, i1, amt = inputs x, i0, i1, amt = inputs
gy = g_outputs[0] gy = g_outputs[0]
return [gy, None, None, diagonal_subtensor(gy, i0, i1)] return [gy, DisconnectedType()(), DisconnectedType()(),
diagonal_subtensor(gy, i0, i1)]
def connection_pattern(self, node):
rval = [[True], [False], [False], [True]]
return rval
inc_diagonal_subtensor = IncDiagonalSubtensor(False) inc_diagonal_subtensor = IncDiagonalSubtensor(False)
def conv3d(signals, filters, def conv3d(signals, filters,
signals_shape=None, filters_shape=None, signals_shape=None, filters_shape=None,
border_mode='valid', subsample=(1,1,1), **kwargs): border_mode='valid', subsample=(1, 1, 1), **kwargs):
""" """
Convolve spatio-temporal filters with a movie. Convolve spatio-temporal filters with a movie.
...@@ -87,8 +112,6 @@ def conv3d(signals, filters, ...@@ -87,8 +112,6 @@ def conv3d(signals, filters,
_signals_shape_5d = signals.shape if signals_shape is None else signals_shape _signals_shape_5d = signals.shape if signals_shape is None else signals_shape
_filters_shape_5d = filters.shape if filters_shape is None else filters_shape _filters_shape_5d = filters.shape if filters_shape is None else filters_shape
_signals_shape_4d = ( _signals_shape_4d = (
_signals_shape_5d[0] * _signals_shape_5d[1], _signals_shape_5d[0] * _signals_shape_5d[1],
_signals_shape_5d[2], _signals_shape_5d[2],
...@@ -106,29 +129,29 @@ def conv3d(signals, filters, ...@@ -106,29 +129,29 @@ def conv3d(signals, filters,
raise NotImplementedError('height and width bordermodes must match') raise NotImplementedError('height and width bordermodes must match')
out_4d = tensor.nnet.conv2d( out_4d = tensor.nnet.conv2d(
signals.reshape(_signals_shape_4d), signals.reshape(_signals_shape_4d),
filters.reshape(_filters_shape_4d), filters.reshape(_filters_shape_4d),
image_shape=_signals_shape_4d, image_shape=_signals_shape_4d,
filter_shape=_filters_shape_4d, filter_shape=_filters_shape_4d,
border_mode = border_mode[1]) #ignoring border_mode[2] border_mode = border_mode[1]) # ignoring border_mode[2]
# reshape the output to restore its original size # reshape the output to restore its original size
# shape = Ns, Ts, Nf, Tf, W-Wf+1, H-Hf+1 # shape = Ns, Ts, Nf, Tf, W-Wf+1, H-Hf+1
if border_mode[1] == 'valid': if border_mode[1] == 'valid':
out_tmp = out_4d.reshape(( out_tmp = out_4d.reshape((
_signals_shape_5d[0], # Ns _signals_shape_5d[0], # Ns
_signals_shape_5d[1], # Ts _signals_shape_5d[1], # Ts
_filters_shape_5d[0], # Nf _filters_shape_5d[0], # Nf
_filters_shape_5d[1], # Tf _filters_shape_5d[1], # Tf
_signals_shape_5d[3] - _filters_shape_5d[3] + 1, _signals_shape_5d[3] - _filters_shape_5d[3] + 1,
_signals_shape_5d[4] - _filters_shape_5d[4] + 1, _signals_shape_5d[4] - _filters_shape_5d[4] + 1,
)) ))
elif border_mode[1] == 'full': elif border_mode[1] == 'full':
out_tmp = out_4d.reshape(( out_tmp = out_4d.reshape((
_signals_shape_5d[0], #Ns _signals_shape_5d[0], # Ns
_signals_shape_5d[1], #Ts _signals_shape_5d[1], # Ts
_filters_shape_5d[0], #Nf _filters_shape_5d[0], # Nf
_filters_shape_5d[1], #Tf _filters_shape_5d[1], # Tf
_signals_shape_5d[3] + _filters_shape_5d[3] - 1, _signals_shape_5d[3] + _filters_shape_5d[3] - 1,
_signals_shape_5d[4] + _filters_shape_5d[4] - 1, _signals_shape_5d[4] + _filters_shape_5d[4] - 1,
)) ))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论