提交 9c9d8395 authored 作者: Yann N. Dauphin's avatar Yann N. Dauphin

added infer shape

上级 7552cf2a
...@@ -1467,6 +1467,19 @@ class Dot(gof.op.Op): ...@@ -1467,6 +1467,19 @@ class Dot(gof.op.Op):
def __ne__(self, other): def __ne__(self, other):
return not (self == other) return not (self == other)
def infer_shape(self, node, shapes):
    """Return the symbolic output shape of the dot product.

    :param node: the Apply node; its two inputs are the operands and
        only their ``ndim`` is consulted.
    :param shapes: a pair ``(x_shape, y_shape)`` with the shapes of
        the two operands.
    :return: a one-element list holding the output shape tuple.
    :raises NotImplementedError: when the operand ranks are not a
        combination of 1 and 2.
    """
    lhs_shape, rhs_shape = shapes
    lhs, rhs = node.inputs
    ranks = (lhs.ndim, rhs.ndim)

    # matrix . matrix -> matrix
    if ranks == (2, 2):
        return [(lhs_shape[0], rhs_shape[1])]
    # vector . matrix -> vector over the matrix columns
    if ranks == (1, 2):
        return [(rhs_shape[1],)]
    # matrix . vector -> vector over the matrix rows
    if ranks == (2, 1):
        return [(lhs_shape[0],)]
    # vector . vector -> scalar
    if ranks == (1, 1):
        return [()]
    raise NotImplementedError()
def make_node(self, x, y): def make_node(self, x, y):
dtype_out = scalar.upcast(x.type.dtype, y.type.dtype) dtype_out = scalar.upcast(x.type.dtype, y.type.dtype)
...@@ -1539,6 +1552,19 @@ class Usmm(gof.op.Op): ...@@ -1539,6 +1552,19 @@ class Usmm(gof.op.Op):
def __str__(self): def __str__(self):
return 'Usmm{no_inplace}' return 'Usmm{no_inplace}'
def infer_shape(self, node, shapes):
    """Return the symbolic output shape of ``alpha * dot(x, y) + z``.

    ``make_node`` takes ``(alpha, x, y, z)``, so the node is expected
    to carry four inputs; unpacking ``shapes`` into two names (as the
    ``Dot`` version does) would raise ``ValueError``.  The shape is
    derived from ``x`` and ``y`` only.

    NOTE(review): assumes the Apply node stores its inputs in the same
    order as ``make_node``'s arguments -- confirm against the
    ``gof.Apply`` call in ``make_node``.

    :param node: the Apply node; ``ndim`` of the ``x`` and ``y``
        inputs is consulted.
    :param shapes: one shape per node input.
    :return: a one-element list holding the output shape tuple.
    :raises NotImplementedError: when the ranks of ``x`` and ``y`` are
        not a combination of 1 and 2.
    """
    if len(shapes) == 4:
        # (alpha, x, y, z) layout: only x and y determine the shape.
        _, xshp, yshp, _ = shapes
        _, x, y, _ = node.inputs
    else:
        # Two-input layout, kept so any existing caller still works.
        xshp, yshp = shapes
        x, y = node.inputs

    if x.ndim == 2 and y.ndim == 2:
        return [(xshp[0], yshp[1])]
    if x.ndim == 1 and y.ndim == 2:
        return [(yshp[1],)]
    if x.ndim == 2 and y.ndim == 1:
        return [(xshp[0],)]
    if x.ndim == 1 and y.ndim == 1:
        return [()]
    raise NotImplementedError()
def make_node(self, alpha, x, y, z): def make_node(self, alpha, x, y, z):
if not _is_sparse_variable(x) and not _is_sparse_variable(y): if not _is_sparse_variable(x) and not _is_sparse_variable(y):
# If x and y are tensor, we don't want to use this class # If x and y are tensor, we don't want to use this class
...@@ -1603,11 +1629,25 @@ class UsmmCscDense(gof.Op): ...@@ -1603,11 +1629,25 @@ class UsmmCscDense(gof.Op):
def __eq__(self, other): def __eq__(self, other):
return (type(self) == type(other)) and self.inplace == other.inplace return (type(self) == type(other)) and self.inplace == other.inplace
def __hash__(self): def __hash__(self):
return hash(type(self)) return hash(type(self)) ^ self.inplace
def infer_shape(self, node, shapes):
    """Return the symbolic output shape of the CSC-times-dense update.

    The node built by ``make_node`` has seven inputs,
    ``[alpha, x_val, x_ind, x_ptr, x_nrows, y, z]``, so unpacking
    ``shapes`` into two names would raise ``ValueError``.  The
    generated C code requires ``z`` to be ``(x_nrows, y.shape[1])``,
    which is the size of the result, so the output shape is the shape
    of ``z``.

    :param node: the Apply node (unused; kept for the Op interface).
    :param shapes: one shape per node input, in the order listed above.
    :return: a one-element list holding the output shape tuple.
    :raises NotImplementedError: when ``shapes`` does not contain the
        seven expected entries.
    """
    if len(shapes) != 7:
        raise NotImplementedError()
    # z is the last input; the result has exactly its dimensions.
    z_shp = shapes[6]
    return [tuple(z_shp)]
def make_node(self, alpha, x_val, x_ind, x_ptr, x_nrows, y, z): def make_node(self, alpha, x_val, x_ind, x_ptr, x_nrows, y, z):
alpha = tensor.as_tensor_variable(alpha)
x_val = tensor.as_tensor_variable(x_val) x_val = tensor.as_tensor_variable(x_val)
x_ind = tensor.as_tensor_variable(x_ind) x_ind = tensor.as_tensor_variable(x_ind)
x_ptr = tensor.as_tensor_variable(x_ptr) x_ptr = tensor.as_tensor_variable(x_ptr)
x_nrows = tensor.as_tensor_variable(x_nrows)
y = tensor.as_tensor_variable(y) y = tensor.as_tensor_variable(y)
z = tensor.as_tensor_variable(z) z = tensor.as_tensor_variable(z)
assert x_ind.dtype == 'int32' assert x_ind.dtype == 'int32'
...@@ -1620,7 +1660,6 @@ class UsmmCscDense(gof.Op): ...@@ -1620,7 +1660,6 @@ class UsmmCscDense(gof.Op):
dtype_out = scalar.upcast(alpha.type.dtype, x_val.type.dtype, dtype_out = scalar.upcast(alpha.type.dtype, x_val.type.dtype,
y.type.dtype, z.type.dtype) y.type.dtype, z.type.dtype)
alpha = tensor.as_tensor_variable(alpha)
if self.inplace: if self.inplace:
assert z.type.dtype == dtype_out assert z.type.dtype == dtype_out
...@@ -1631,10 +1670,14 @@ class UsmmCscDense(gof.Op): ...@@ -1631,10 +1670,14 @@ class UsmmCscDense(gof.Op):
if dtype_out != x_val.type.dtype: if dtype_out != x_val.type.dtype:
x_val = tensor.cast(x_val, dtype_out) x_val = tensor.cast(x_val, dtype_out)
if dtype_out != y.type.dtype: if dtype_out != y.type.dtype:
raise NotImplementedError("We need sparse cast to be implemented!") y = tensor.cast(y, dtype_out)
if dtype_out != z.type.dtype: if dtype_out != z.type.dtype:
z = tensor.cast(z, dtype_out) z = tensor.cast(z, dtype_out)
if node.inputs[1].type.dtype in ('complex64', 'complex128'):
raise NotImplementedError('Complex types are not supported for x_val')
if node.inputs[5].type.dtype in ('complex64', 'complex128'):
raise NotImplementedError('Complex types are not supported for y')
r = gof.Apply(self, [alpha, x_val, x_ind, x_ptr, x_nrows, y, z], r = gof.Apply(self, [alpha, x_val, x_ind, x_ptr, x_nrows, y, z],
[tensor.tensor(dtype_out, (False, y.type.broadcastable[1]))]) [tensor.tensor(dtype_out, (False, y.type.broadcastable[1]))])
...@@ -1716,7 +1759,7 @@ class UsmmCscDense(gof.Op): ...@@ -1716,7 +1759,7 @@ class UsmmCscDense(gof.Op):
{PyErr_SetString(PyExc_NotImplementedError, "x's number of columns doesn't match y's rows"); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError, "x's number of columns doesn't match y's rows"); %(fail)s;}
if (%(z)s->dimensions[0] != ((npy_int32 *)%(x_nrows)s->data)[0] || %(z)s->dimensions[1] != %(y)s->dimensions[1]) if (%(z)s->dimensions[0] != ((npy_int32 *)%(x_nrows)s->data)[0] || %(z)s->dimensions[1] != %(y)s->dimensions[1])
{PyErr_SetString(PyExc_NotImplementedError, "The dimension of z and the output must match"); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError, "The dimension of the allocated output doesn't match the correct output size."); %(fail)s;}
if (PyArray_SIZE(%(alpha)s) != 1) if (PyArray_SIZE(%(alpha)s) != 1)
{PyErr_SetString(PyExc_NotImplementedError, "The number of element in alpha must be 1"); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError, "The number of element in alpha must be 1"); %(fail)s;}
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论