提交 80afdbeb authored 作者: Razvan Pascanu's avatar Razvan Pascanu

use user constructor only for float32

上级 6824b219
......@@ -76,21 +76,40 @@ class Scan(PureOp):
# Not that for mit_mot there are several output slices per
# output sequence
o = outputs[idx]
self.output_types.append(
typeConstructor(
broadcastable=(False,) + o.type.broadcastable,
dtype=o.type.dtype)
)
# typeConstructor constructs only CudaNdarray when code is
# runnnig on gpu, but for outputs of a dtype different than
# float32 this results into an error
if o.type.dtype in ['float32']:
self.output_types.append(
typeConstructor(
broadcastable=(False,) + o.type.broadcastable,
dtype=o.type.dtype))
else:
self.output_types.append(
tensorConstructor(
broadcastable=(False,) + o.type.broadcastable,
dtype=o.type.dtype))
idx += len(self.mit_mot_out_slices[jdx])
jdx += 1
# mit_sot / sit_sot / nit_sot
end = idx + self.n_mit_sot + self.n_sit_sot + self.n_nit_sot
for o in outputs[idx:end]:
self.output_types.append(
typeConstructor(
broadcastable=(False,) + o.type.broadcastable,
dtype=o.type.dtype))
# typeConstructor constructs only CudaNdarray when code is
# runnnig on gpu, but for outputs of a dtype different than
# float32 this results into an error
if o.type.dtype in ['float32']:
self.output_types.append(
typeConstructor(
broadcastable=(False,) + o.type.broadcastable,
dtype=o.type.dtype))
else:
self.output_types.append(
tensorConstructor(
broadcastable=(False,) + o.type.broadcastable,
dtype=o.type.dtype))
# shared outputs + possibly the ending condition
for o in outputs[end:]:
self.output_types.append(o.type)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论