提交 8ff30ed8 authored 作者: Frederic's avatar Frederic

Remove badly implemented infer_shape as they are not needed and generated error when run.

上级 b736b12e
...@@ -1563,6 +1563,8 @@ class Usmm(gof.op.Op): ...@@ -1563,6 +1563,8 @@ class Usmm(gof.op.Op):
x or y are sparse matrix(the other can be sparse or dense) x or y are sparse matrix(the other can be sparse or dense)
z is a dense matrix z is a dense matrix
alpha is a scalar alpha is a scalar
:note: We don't implement the infer_shape as it is inserted by optimization only
""" """
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) return type(self) == type(other)
...@@ -1576,19 +1578,6 @@ class Usmm(gof.op.Op): ...@@ -1576,19 +1578,6 @@ class Usmm(gof.op.Op):
def __str__(self): def __str__(self):
return 'Usmm{no_inplace}' return 'Usmm{no_inplace}'
def infer_shape(self, node, shapes):
xshp, yshp = shapes
x, y = node.inputs
if x.ndim == 2 and y.ndim == 2:
return [(xshp[0], yshp[1])]
if x.ndim == 1 and y.ndim == 2:
return [(yshp[1],)]
if x.ndim == 2 and y.ndim == 1:
return [(xshp[0],)]
if x.ndim == 1 and y.ndim == 1:
return [()]
raise NotImplementedError()
def make_node(self, alpha, x, y, z): def make_node(self, alpha, x, y, z):
if not _is_sparse_variable(x) and not _is_sparse_variable(y): if not _is_sparse_variable(x) and not _is_sparse_variable(y):
# If x and y are tensor, we don't want to use this class # If x and y are tensor, we don't want to use this class
...@@ -1644,6 +1633,8 @@ class UsmmCscDense(gof.Op): ...@@ -1644,6 +1633,8 @@ class UsmmCscDense(gof.Op):
x are sparse matrix x are sparse matrix
y, z is a dense matrix y, z is a dense matrix
alpha is a scalar alpha is a scalar
:note: We don't implement the infer_shape as it is inserted by optimization only
""" """
def __init__(self, inplace): def __init__(self, inplace):
self.inplace = inplace self.inplace = inplace
...@@ -1662,19 +1653,6 @@ class UsmmCscDense(gof.Op): ...@@ -1662,19 +1653,6 @@ class UsmmCscDense(gof.Op):
def __hash__(self): def __hash__(self):
return hash(type(self)) ^ self.inplace return hash(type(self)) ^ self.inplace
def infer_shape(self, node, shapes):
xshp, yshp = shapes
x, y = node.inputs
if x.ndim == 2 and y.ndim == 2:
return [(xshp[0], yshp[1])]
if x.ndim == 1 and y.ndim == 2:
return [(yshp[1],)]
if x.ndim == 2 and y.ndim == 1:
return [(xshp[0],)]
if x.ndim == 1 and y.ndim == 1:
return [()]
raise NotImplementedError()
def make_node(self, alpha, x_val, x_ind, x_ptr, x_nrows, y, z): def make_node(self, alpha, x_val, x_ind, x_ptr, x_nrows, y, z):
alpha = tensor.as_tensor_variable(alpha) alpha = tensor.as_tensor_variable(alpha)
x_val = tensor.as_tensor_variable(x_val) x_val = tensor.as_tensor_variable(x_val)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论