提交 8498dfdd authored 作者: Yann N. Dauphin's avatar Yann N. Dauphin

replaced multiple checks with simpler check

上级 9f401058
......@@ -1662,6 +1662,9 @@ class UsmmCscDense(gof.Op):
dtype_out = scalar.upcast(alpha.type.dtype, x_val.type.dtype,
y.type.dtype, z.type.dtype)
if dtype_out in ('complex64', 'complex128'):
raise NotImplementedError('Complex types are not supported in operands')
if self.inplace:
assert z.type.dtype == dtype_out
......@@ -1675,11 +1678,6 @@ class UsmmCscDense(gof.Op):
if dtype_out != z.type.dtype:
z = tensor.cast(z, dtype_out)
if node.inputs[1].type.dtype in ('complex64', 'complex128'):
raise NotImplementedError('Complex types are not supported for x_val')
if node.inputs[5].type.dtype in ('complex64', 'complex128'):
raise NotImplementedError('Complex types are not supported for y')
r = gof.Apply(self, [alpha, x_val, x_ind, x_ptr, x_nrows, y, z],
[tensor.tensor(dtype_out, (False, y.type.broadcastable[1]))])
return r
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论