提交 56f011a4 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

figuring out the right dtype

上级 39a2eb78
......@@ -1291,17 +1291,20 @@ class Scan(PureOp):
if idx >= self.n_mit_mot_outs:
Xt_placeholder = Xt.type()
Xts.append(Xt_placeholder)
Xtm1_pos = self.get_input_pos(idx)
if Xtm1_pos >= 0:
Xtm1 = self_inputs[Xtm1_pos]
# It is possible that X[t] is not actually a function of
# x[t-1], case in which we can not rely on this information
try:
tmp = tensor.grad(Xt.sum(), Xtm1)
except ValueError:
tmp = Xt
dC_dXt = safe_new(tmp)
if Xt not in self.inner_nitsot_outs(self_outputs):
dtypes = []
states = (self.inner_mitmot(self_inputs) +
self.inner_mitsot(self_inputs) +
self.inner_sitsot(self_inputs))
for pos, inp in enumerate(states):
if inp in theano.gof.graph.inputs([Xt]):
oidx = self.get_output_pos(pos)
if not isinstance(dC_douts[oidx].type,
DisconnectedType):
dtypes.append(dC_douts[oidx].dtype)
new_dtype = theano.scalar.upcast(*dtypes)
dC_dXt = safe_new(Xt, dtype=new_dtype)
else:
if isinstance(dC_douts[idx].type, DisconnectedType):
continue
......@@ -1331,7 +1334,17 @@ class Scan(PureOp):
dC_dinps_t[dx] = tmp
# construct dX_dtm1
dC_dXtm1s = [x.type() for x in dC_dinps_t[self.n_seqs:]]
dC_dXtm1s = []
for pos, x in enumerate(dC_dinps_t[self.n_seqs:]):
opos = self.get_output_pos(pos)
if opos >= 0:
dC_dXtm1s.append(dC_dXts[opos].type())
if x.dtype != dC_dXts[opos].dtype:
dC_dinps_t[pos + self.n_seqs] = \
tensor.cast(x,
dtype=dC_dXts[opos].dtype)
else:
dC_dXtm1s.append(x.type())
for dx, dC_dXtm1 in enumerate(dC_dXtm1s):
dC_dinps_t[dx + self.n_seqs] += dC_dXtm1
# Construct scan op
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论