提交 8759f6a7 作者: Razvan Pascanu

return disconnected type instead of undefined

上级 781f9e8b
...@@ -1249,7 +1249,8 @@ class Scan(PureOp): ...@@ -1249,7 +1249,8 @@ class Scan(PureOp):
return ipos + opos return ipos + opos
def connection_pattern(self, node): def connection_pattern(self, node):
connection_pattern = [[True for output in node.outputs]] # The gradient wrt to n_steps is disconnected
connection_pattern = [[False for output in node.outputs]]
connection_pattern += [[False for output in node.outputs] connection_pattern += [[False for output in node.outputs]
for x in node.inputs[1:]] for x in node.inputs[1:]]
...@@ -1286,8 +1287,7 @@ class Scan(PureOp): ...@@ -1286,8 +1287,7 @@ class Scan(PureOp):
else: else:
e = len(self.tap_array[0]) e = len(self.tap_array[0])
p = iidx p = iidx
if (node.inputs[iidx + 1] in self.outer_nitsot(node) or if node.inputs[iidx + 1] in self.outer_nitsot(node)
node.inputs[iidx + 1] in self.outer_shared(node)):
return None return None
if node.inputs[iidx + 1] in self.outer_non_seqs(node): if node.inputs[iidx + 1] in self.outer_non_seqs(node):
loc_idx = self.outer_non_seqs(node).index( loc_idx = self.outer_non_seqs(node).index(
...@@ -1298,8 +1298,10 @@ class Scan(PureOp): ...@@ -1298,8 +1298,10 @@ class Scan(PureOp):
s = e s = e
if p < self.n_seqs: if p < self.n_seqs:
e += 1 e += 1
else: elif p - self.n_seqs < len(self.tap_array):
e += len(self.tap_array[p - self.n_seqs]) e += len(self.tap_array[p - self.n_seqs])
else:
e += 1
return self.inputs[s:e] return self.inputs[s:e]
for oidx, out in enumerate(node.outputs): for oidx, out in enumerate(node.outputs):
...@@ -1308,8 +1310,8 @@ class Scan(PureOp): ...@@ -1308,8 +1310,8 @@ class Scan(PureOp):
ils = _get_inner_inps(iidx) ils = _get_inner_inps(iidx)
if ils is None: if ils is None:
# The gradient should be undefined, not disconnected # The gradient should be disconnected
connection_pattern[iidx + 1][oidx] = True connection_pattern[iidx + 1][oidx] = False
else: else:
for inner_out in ols: for inner_out in ols:
if hasattr(inner_out, 'dtype'): if hasattr(inner_out, 'dtype'):
...@@ -1662,7 +1664,7 @@ class Scan(PureOp): ...@@ -1662,7 +1664,7 @@ class Scan(PureOp):
if type(outputs) not in (list, tuple): if type(outputs) not in (list, tuple):
outputs = [outputs] outputs = [outputs]
# Re-order the gradients correctly # Re-order the gradients correctly
gradients = [grad_undefined(self, 0, inputs[0], 'Number of steps')] gradients = [DisconnectedType()()]
offset = (self.n_mit_mot + offset = (self.n_mit_mot +
self.n_mit_sot + self.n_mit_sot +
...@@ -1702,9 +1704,7 @@ class Scan(PureOp): ...@@ -1702,9 +1704,7 @@ class Scan(PureOp):
'Shared Variable with update') 'Shared Variable with update')
for x in xrange(self.n_shared_outs)] for x in xrange(self.n_shared_outs)]
start = len(gradients) start = len(gradients)
gradients += [ gradients += [DisconnectedType()()
grad_undefined(self, x + start, inputs[x + start],
'Dimension of memory buffer for output')
for x in xrange(self.n_nit_sot)] for x in xrange(self.n_nit_sot)]
begin = end begin = end
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论