提交 aaa49e3c authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Fix for some shape-related errors on 32-bit

上级 2d774b33
...@@ -920,6 +920,28 @@ class ShapeFeature(object): ...@@ -920,6 +920,28 @@ class ShapeFeature(object):
+ ' != len(node.outputs) = ' + ' != len(node.outputs) = '
+ str(len(node.outputs))) + str(len(node.outputs)))
# Ensure shapes are in 'int64'. This is to make sure the assert
# found in the `local_useless_subtensor` optimization does not fail.
new_shape = []
for sh_idx, sh in enumerate(o_shapes):
if sh is None:
continue
for i, d in enumerate(sh):
# Note: we ignore any shape element that is not typed (i.e. does
# not have a 'dtype' attribute). This means there may still
# remain int elements that are int32 on 32-bit platforms, but
# this works with `local_useless_subtensor`, so for now we
# keep it this way. See #266 for a better long-term fix.
if getattr(d, 'dtype', 'int64') != 'int64':
assert d.dtype in theano.tensor.int_dtypes
new_shape += sh[len(new_shape):i + 1]
new_shape[i] = theano.tensor.cast(d, 'int64')
if new_shape:
# We replace the shape with wrong dtype by the one with 'int64'.
new_shape += sh[len(new_shape):]
o_shapes[sh_idx] = tuple(new_shape)
new_shape = []
for r, s in izip(node.outputs, o_shapes): for r, s in izip(node.outputs, o_shapes):
self.set_shape(r, s) self.set_shape(r, s)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论