提交 0e226ff7 authored 作者: David Warde-Farley's avatar David Warde-Farley

Add infer_shape for sparse Ops that are missing it.

Closes gh-414.
上级 c19ecf50
...@@ -655,6 +655,13 @@ class CSMGrad(gof.op.Op): ...@@ -655,6 +655,13 @@ class CSMGrad(gof.op.Op):
grad = numpy.zeros_like(data) grad = numpy.zeros_like(data)
grad[self.kmap] = gout_data grad[self.kmap] = gout_data
g_data[0] = grad g_data[0] = grad
def infer_shape(self, node, shapes):
if self.kmap is None:
return [shapes[1]]
else:
return [shapes[0]]
csm_grad = CSMGrad csm_grad = CSMGrad
...@@ -936,6 +943,10 @@ class Transpose(gof.op.Op): ...@@ -936,6 +943,10 @@ class Transpose(gof.op.Op):
def grad(self, (x,), (gz,)): def grad(self, (x,), (gz,)):
assert _is_sparse_variable(x) and _is_sparse_variable(gz) assert _is_sparse_variable(x) and _is_sparse_variable(gz)
return transpose(gz), return transpose(gz),
def infer_shape(self, node, shapes):
return [shapes[0][::-1]]
transpose = Transpose() transpose = Transpose()
...@@ -957,6 +968,10 @@ class Neg(gof.op.Op): ...@@ -957,6 +968,10 @@ class Neg(gof.op.Op):
def grad(self, (x,), (gz,)): def grad(self, (x,), (gz,)):
assert _is_sparse_variable(x) and _is_sparse_variable(gz) assert _is_sparse_variable(x) and _is_sparse_variable(gz)
return -gz, return -gz,
def infer_shape(self, node, shapes):
return [shapes[0]]
neg = Neg() neg = Neg()
...@@ -989,6 +1004,10 @@ class AddSS(gof.op.Op): ...@@ -989,6 +1004,10 @@ class AddSS(gof.op.Op):
assert _is_sparse_variable(x) and _is_sparse_variable(y) assert _is_sparse_variable(x) and _is_sparse_variable(y)
assert _is_sparse_variable(gz) assert _is_sparse_variable(gz)
return gz, gz return gz, gz
def infer_shape(self, node, shapes):
return [shapes[0]]
add_s_s = AddSS() add_s_s = AddSS()
...@@ -1023,6 +1042,10 @@ class AddSD(gof.op.Op): ...@@ -1023,6 +1042,10 @@ class AddSD(gof.op.Op):
assert _is_sparse_variable(x) and _is_dense_variable(y) assert _is_sparse_variable(x) and _is_dense_variable(y)
assert _is_dense_variable(gz) assert _is_dense_variable(gz)
return sp_ones_like(x) * gz, gz return sp_ones_like(x) * gz, gz
def infer_shape(self, node, shapes):
return [shapes[0]]
add_s_d = AddSD() add_s_d = AddSD()
...@@ -1080,6 +1103,10 @@ class MulSS(gof.op.Op): ...@@ -1080,6 +1103,10 @@ class MulSS(gof.op.Op):
def grad(self, (x, y), (gz,)): def grad(self, (x, y), (gz,)):
return y * gz, x * gz return y * gz, x * gz
def infer_shape(self, node, shapes):
return [shapes[0]]
mul_s_s = MulSS() mul_s_s = MulSS()
...@@ -1155,6 +1182,10 @@ class MulSD(gof.op.Op): ...@@ -1155,6 +1182,10 @@ class MulSD(gof.op.Op):
assert _is_sparse_variable(x) and _is_dense_variable(y) assert _is_sparse_variable(x) and _is_dense_variable(y)
assert _is_sparse_variable(gz) assert _is_sparse_variable(gz)
return y * gz, x * gz return y * gz, x * gz
def infer_shape(self, node, shapes):
return [shapes[0]]
mul_s_d = MulSD() mul_s_d = MulSD()
...@@ -1259,6 +1290,10 @@ class StructuredDot(gof.Op): ...@@ -1259,6 +1290,10 @@ class StructuredDot(gof.Op):
# ga = g_out x b.T # ga = g_out x b.T
# gb = a.T x g_out # gb = a.T x g_out
return [structured_dot_grad(a, b, g_out), structured_dot(a.T, g_out)] return [structured_dot_grad(a, b, g_out), structured_dot(a.T, g_out)]
def infer_shape(self, node, shapes):
return [(shapes[0][0], shapes[1][1])]
_structured_dot = StructuredDot() _structured_dot = StructuredDot()
...@@ -1753,6 +1788,10 @@ class StructuredDotGradCSC(gof.Op): ...@@ -1753,6 +1788,10 @@ class StructuredDotGradCSC(gof.Op):
} }
""" % dict(locals(), **sub) """ % dict(locals(), **sub)
def infer_shape(self, node, shapes):
return [shapes[0]]
sdg_csc = StructuredDotGradCSC() sdg_csc = StructuredDotGradCSC()
...@@ -1866,6 +1905,8 @@ class StructuredDotGradCSR(gof.Op): ...@@ -1866,6 +1905,8 @@ class StructuredDotGradCSR(gof.Op):
""" % dict(locals(), **sub) """ % dict(locals(), **sub)
def infer_shape(self, node, shapes):
return [shapes[0]]
sdg_csr = StructuredDotGradCSR() sdg_csr = StructuredDotGradCSR()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论