提交 c19ecf50 authored 作者: David Warde-Farley's avatar David Warde-Farley

Uniformize signature of existing infer_shapes.

Remove the tuple argument unpacking which is deprecated in Python 3.
上级 0f86ecd9
...@@ -714,8 +714,9 @@ class DenseFromSparse(gof.op.Op): ...@@ -714,8 +714,9 @@ class DenseFromSparse(gof.op.Op):
else: else:
return [SparseFromDense(x.type.format)(gz)] return [SparseFromDense(x.type.format)(gz)]
def infer_shape(self, node, (ishape,)): def infer_shape(self, node, shapes):
return [ishape] return [shapes[0]]
dense_from_sparse = DenseFromSparse() dense_from_sparse = DenseFromSparse()
...@@ -749,8 +750,9 @@ class SparseFromDense(gof.op.Op): ...@@ -749,8 +750,9 @@ class SparseFromDense(gof.op.Op):
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return dense_from_sparse(gz), return dense_from_sparse(gz),
def infer_shape(self, node, (ishape,)): def infer_shape(self, node, shapes):
return [ishape] return [shapes[0]]
csr_from_dense = SparseFromDense('csr') csr_from_dense = SparseFromDense('csr')
csc_from_dense = SparseFromDense('csc') csc_from_dense = SparseFromDense('csc')
...@@ -870,7 +872,7 @@ class GetItemScalar(gof.op.Op): ...@@ -870,7 +872,7 @@ class GetItemScalar(gof.op.Op):
def __hash__(self): def __hash__(self):
return hash(type(self)) return hash(type(self))
def infer_shape(self, node, i0_shapes): def infer_shape(self, node, shapes):
return [()] return [()]
def make_node(self, x, index): def make_node(self, x, index):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论