提交 d59aaf11 authored 作者: Frederic's avatar Frederic

Safer indexing. If we append more information, it will work correctly.

上级 58782af5
...@@ -130,9 +130,9 @@ class StructuredDotCSC(gof.Op): ...@@ -130,9 +130,9 @@ class StructuredDotCSC(gof.Op):
if node.inputs[4].type.dtype in ('complex64', 'complex128'): if node.inputs[4].type.dtype in ('complex64', 'complex128'):
raise NotImplementedError('Complex types are not supported for b') raise NotImplementedError('Complex types are not supported for b')
typenum_z = node.outputs[0].type.dtype_specs()[-1] # retrieve dtype number typenum_z = node.outputs[0].type.dtype_specs()[2] # retrieve dtype number
typenum_a_val = node.inputs[0].type.dtype_specs()[-1] # retrieve dtype number typenum_a_val = node.inputs[0].type.dtype_specs()[2] # retrieve dtype number
typenum_b = node.inputs[4].type.dtype_specs()[-1] # retrieve dtype number typenum_b = node.inputs[4].type.dtype_specs()[2] # retrieve dtype number
rval = """ rval = """
...@@ -318,7 +318,7 @@ class StructuredDotCSR(gof.Op): ...@@ -318,7 +318,7 @@ class StructuredDotCSR(gof.Op):
@param sub: TODO, not too sure, something to do with weave probably @param sub: TODO, not too sure, something to do with weave probably
""" """
# retrieve dtype number # retrieve dtype number
typenum_z = tensor.TensorType(self.dtype_out, []).dtype_specs()[-1] typenum_z = tensor.TensorType(self.dtype_out, []).dtype_specs()[2]
if node.inputs[0].type.dtype in ('complex64', 'complex128'): if node.inputs[0].type.dtype in ('complex64', 'complex128'):
raise NotImplementedError('Complex types are not supported for a_val') raise NotImplementedError('Complex types are not supported for a_val')
if node.inputs[3].type.dtype in ('complex64', 'complex128'): if node.inputs[3].type.dtype in ('complex64', 'complex128'):
...@@ -550,11 +550,11 @@ class UsmmCscDense(gof.Op): ...@@ -550,11 +550,11 @@ class UsmmCscDense(gof.Op):
conv_type = "double" conv_type = "double"
axpy = "daxpy_" axpy = "daxpy_"
# retrieve dtype numbers # retrieve dtype numbers
typenum_alpha = node.inputs[0].type.dtype_specs()[-1] typenum_alpha = node.inputs[0].type.dtype_specs()[2]
typenum_x_val = node.inputs[1].type.dtype_specs()[-1] typenum_x_val = node.inputs[1].type.dtype_specs()[2]
typenum_y = node.inputs[5].type.dtype_specs()[-1] typenum_y = node.inputs[5].type.dtype_specs()[2]
typenum_z = node.inputs[6].type.dtype_specs()[-1] typenum_z = node.inputs[6].type.dtype_specs()[2]
typenum_zn = node.outputs[0].type.dtype_specs()[-1] typenum_zn = node.outputs[0].type.dtype_specs()[2]
inplace = int(self.inplace) inplace = int(self.inplace)
...@@ -761,7 +761,7 @@ class CSMGradC(gof.Op): ...@@ -761,7 +761,7 @@ class CSMGradC(gof.Op):
def c_code(self, node, name, (a_val, a_ind, a_ptr, a_dim, def c_code(self, node, name, (a_val, a_ind, a_ptr, a_dim,
b_val, b_ind, b_ptr, b_dim), (z,), sub): b_val, b_ind, b_ptr, b_dim), (z,), sub):
# retrieve dtype number # retrieve dtype number
typenum_z = node.outputs[0].type.dtype_specs()[-1] typenum_z = node.outputs[0].type.dtype_specs()[2]
if node.inputs[0].type.dtype in ('complex64', 'complex128'): if node.inputs[0].type.dtype in ('complex64', 'complex128'):
raise NotImplementedError('Complex types are not supported for a_val') raise NotImplementedError('Complex types are not supported for a_val')
if node.inputs[3].type.dtype in ('complex64', 'complex128'): if node.inputs[3].type.dtype in ('complex64', 'complex128'):
...@@ -1558,15 +1558,15 @@ class SamplingDotCSR(gof.Op): ...@@ -1558,15 +1558,15 @@ class SamplingDotCSR(gof.Op):
cdot = "ddot_" cdot = "ddot_"
# retrieve dtype number # retrieve dtype number
typenum_x = node.inputs[0].type.dtype_specs()[-1] typenum_x = node.inputs[0].type.dtype_specs()[2]
typenum_y = node.inputs[1].type.dtype_specs()[-1] typenum_y = node.inputs[1].type.dtype_specs()[2]
typenum_p = node.inputs[2].type.dtype_specs()[-1] typenum_p = node.inputs[2].type.dtype_specs()[2]
typenum_zd = tensor.TensorType(node.outputs[0].dtype, typenum_zd = tensor.TensorType(node.outputs[0].dtype,
[]).dtype_specs()[-1] []).dtype_specs()[2]
typenum_zi = tensor.TensorType(node.outputs[1].dtype, typenum_zi = tensor.TensorType(node.outputs[1].dtype,
[]).dtype_specs()[-1] []).dtype_specs()[2]
typenum_zp = tensor.TensorType(node.outputs[2].dtype, typenum_zp = tensor.TensorType(node.outputs[2].dtype,
[]).dtype_specs()[-1] []).dtype_specs()[2]
rval = """ rval = """
if (PyArray_NDIM(%(x)s) != 2) { if (PyArray_NDIM(%(x)s) != 2) {
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论