提交 3bc33ddb authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Merged

...@@ -343,22 +343,29 @@ class CSM(gof.Op): ...@@ -343,22 +343,29 @@ class CSM(gof.Op):
""" """
data = tensor.as_tensor_variable(data) data = tensor.as_tensor_variable(data)
# Note that we use `view(numpy.int32)` instead of providing the 'int32'
# dtype to `numpy.asarray`. This is because on some computers (e.g. a # Note that we use `view(numpy.int32)` in addition to providing the
# Windows 32 bits machine), we can have the following assert fail: # 'int32' dtype to `numpy.asarray`. This is because on some computers
# (e.g. a Windows 32 bits machine), we can have the following assert
# fail:
# x = numpy.array([0], dtype=numpy.intc) # x = numpy.array([0], dtype=numpy.intc)
# y = numpy.asarray(x, dtype=numpy.int32) # y = numpy.asarray(x, dtype=numpy.int32)
# assert y.dtype.num == numpy.dtype(numpy.int32).num # assert y.dtype.num == numpy.dtype(numpy.int32).num
# while the assert does *not* fail when replacing the second line by: # while the assert does *not* fail when replacing the second line by:
# y = numpy.asarray(x).view(numpy.int32) # y = numpy.asarray(x, dtype='int32').view(numpy.int32)
# This is a known defect in Numpy. For more information see ticket # This is a known defect in Numpy. For more information see ticket
# http://projects.scipy.org/numpy/ticket/870 # http://projects.scipy.org/numpy/ticket/870
# Note also that it is important to keep "dtype='int32'" when calling
# `numpy.asarray`. This is because `view` is only some kind of cast to
# the exact data type we want to use. If a conversion is required (e.g.
# from int64 to int32), it must be done in the call to `numpy.asarray`.
if not isinstance(indices, tensor.TensorVariable): if not isinstance(indices, tensor.TensorVariable):
indices = numpy.asarray(indices).view(numpy.int32) indices = numpy.asarray(indices, dtype='int32').view(numpy.int32)
if not isinstance(indptr, tensor.TensorVariable): if not isinstance(indptr, tensor.TensorVariable):
indptr = numpy.asarray(indptr).view(numpy.int32) indptr = numpy.asarray(indptr, dtype='int32').view(numpy.int32)
if not isinstance(shape, tensor.TensorVariable): if not isinstance(shape, tensor.TensorVariable):
shape = numpy.asarray(shape).view(numpy.int32) shape = numpy.asarray(shape, dtype='int32').view(numpy.int32)
indices = tensor.as_tensor_variable(indices) indices = tensor.as_tensor_variable(indices)
indptr = tensor.as_tensor_variable(indptr) indptr = tensor.as_tensor_variable(indptr)
shape = tensor.as_tensor_variable(shape) shape = tensor.as_tensor_variable(shape)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论