提交 56f011a4 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

figuring out the right dtype

上级 39a2eb78
...@@ -1291,17 +1291,20 @@ class Scan(PureOp): ...@@ -1291,17 +1291,20 @@ class Scan(PureOp):
if idx >= self.n_mit_mot_outs: if idx >= self.n_mit_mot_outs:
Xt_placeholder = Xt.type() Xt_placeholder = Xt.type()
Xts.append(Xt_placeholder) Xts.append(Xt_placeholder)
if Xt not in self.inner_nitsot_outs(self_outputs):
Xtm1_pos = self.get_input_pos(idx) dtypes = []
if Xtm1_pos >= 0: states = (self.inner_mitmot(self_inputs) +
Xtm1 = self_inputs[Xtm1_pos] self.inner_mitsot(self_inputs) +
# It is possible that X[t] is not actually a function of self.inner_sitsot(self_inputs))
# x[t-1], case in which we can not rely on this information
try: for pos, inp in enumerate(states):
tmp = tensor.grad(Xt.sum(), Xtm1) if inp in theano.gof.graph.inputs([Xt]):
except ValueError: oidx = self.get_output_pos(pos)
tmp = Xt if not isinstance(dC_douts[oidx].type,
dC_dXt = safe_new(tmp) DisconnectedType):
dtypes.append(dC_douts[oidx].dtype)
new_dtype = theano.scalar.upcast(*dtypes)
dC_dXt = safe_new(Xt, dtype=new_dtype)
else: else:
if isinstance(dC_douts[idx].type, DisconnectedType): if isinstance(dC_douts[idx].type, DisconnectedType):
continue continue
...@@ -1331,7 +1334,17 @@ class Scan(PureOp): ...@@ -1331,7 +1334,17 @@ class Scan(PureOp):
dC_dinps_t[dx] = tmp dC_dinps_t[dx] = tmp
# construct dX_dtm1 # construct dX_dtm1
dC_dXtm1s = [x.type() for x in dC_dinps_t[self.n_seqs:]] dC_dXtm1s = []
for pos, x in enumerate(dC_dinps_t[self.n_seqs:]):
opos = self.get_output_pos(pos)
if opos >= 0:
dC_dXtm1s.append(dC_dXts[opos].type())
if x.dtype != dC_dXts[opos].dtype:
dC_dinps_t[pos + self.n_seqs] = \
tensor.cast(x,
dtype=dC_dXts[opos].dtype)
else:
dC_dXtm1s.append(x.type())
for dx, dC_dXtm1 in enumerate(dC_dXtm1s): for dx, dC_dXtm1 in enumerate(dC_dXtm1s):
dC_dinps_t[dx + self.n_seqs] += dC_dXtm1 dC_dinps_t[dx + self.n_seqs] += dC_dXtm1
# Construct scan op # Construct scan op
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论