提交 52974f37 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

the R_op can return None for non differentiable outs

上级 f2257003
......@@ -1591,7 +1591,14 @@ class Scan(PureOp):
ib = 0
e = 1 + self.n_seqs
ie = self.n_seqs
scan_seqs = inputs[b:e] + eval_points[b:e]
clean_eval_points = []
for inp, evp in zip(inputs[b:e], eval_points[b:e]):
if evp is not None:
clean_eval_points.append(evp)
else:
clean_eval_points.append(inp.zeros_like())
scan_seqs = inputs[b:e] + clean_eval_points
inner_seqs = self_inputs[ib:ie] + inner_eval_points[ib:ie]
# MIT_MOT sequences ...
......@@ -1600,7 +1607,14 @@ class Scan(PureOp):
ib = ie
ie = ie + int(numpy.sum([len(x) for x in
self.tap_array[:self.n_mit_mot]]))
scan_mit_mot = inputs[b:e] + eval_points[b:e]
clean_eval_points = []
for inp, evp in zip(inputs[b:e], eval_points[b:e]):
if evp is not None:
clean_eval_points.append(evp)
else:
clean_eval_points.append(inp.zeros_like())
scan_mit_mot = inputs[b:e] + clean_eval_points
inner_mit_mot = self_inputs[ib:ie] + inner_eval_points[ib:ie]
# MIT_SOT sequences ...
......@@ -1610,6 +1624,13 @@ class Scan(PureOp):
ie = ie + int(numpy.sum([len(x) for x in
self.tap_array[self.n_mit_mot:\
self.n_mit_mot + self.n_mit_sot]]))
clean_eval_points = []
for inp, evp in zip(inputs[b:e], eval_points[b:e]):
if evp is not None:
clean_eval_points.append(evp)
else:
clean_eval_points.append(inp.zeros_like())
scan_mit_sot = inputs[b:e] + eval_points[b:e]
inner_mit_sot = self_inputs[ib:ie] + inner_eval_points[ib:ie]
......@@ -1618,7 +1639,14 @@ class Scan(PureOp):
e = e + self.n_sit_sot
ib = ie
ie = ie + self.n_sit_sot
scan_sit_sot = inputs[b:e] + eval_points[b:e]
clean_eval_points = []
for inp, evp in zip(inputs[b:e], eval_points[b:e]):
if evp is not None:
clean_eval_points.append(evp)
else:
clean_eval_points.append(inp.zeros_like())
scan_sit_sot = inputs[b:e] + clean_eval_points
inner_sit_sot = self_inputs[ib:ie] + inner_eval_points[ib:ie]
#Shared outs ...
......@@ -1635,8 +1663,15 @@ class Scan(PureOp):
scan_nit_sot = inputs[b:e] * 2
# All other arguments
scan_other = inputs[e:] + eval_points[e:]
inner_other = self_inputs[ie:] + inner_eval_points[ie:]
clean_eval_points = []
for inp, evp in zip(inputs[e:], eval_points[e:]):
if evp is not None:
clean_eval_points.append(evp)
else:
clean_eval_points.append(inp.zeros_like())
scan_other = inputs[e:] + clean_eval_points
# inner_eval_points do not have entries for shared variables
inner_other = self_inputs[ie:] + inner_eval_points[ib:]
# Outputs
n_mit_mot_outs = int(numpy.sum([len(x) for x in
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论