提交 53ec8e33 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Do not try to infer artificial connection patterns in OpFromGraph

上级 ab304cb9
......@@ -890,30 +890,9 @@ class OpFromGraph(Op, HasInnerGraph):
if self._connection_pattern is not None:
return self._connection_pattern
inp_len = len(self.inner_inputs)
out_len = len(self.inner_outputs)
cpmat_self = io_connection_pattern(self.inner_inputs, self.inner_outputs)
lop_op = self.get_lop_op()
cpmat_grad = io_connection_pattern(
lop_op.inner_inputs[inp_len:], lop_op.inner_outputs
)
# cpmat_self |= cpmat_grad.T
# cpmat_self &= out_is_disconnected
for i, t in enumerate(self._lop_op_stypes_l):
if t is not None:
if isinstance(t.type, DisconnectedType):
for o in range(out_len):
cpmat_self[i][o] = False
for o in range(out_len):
cpmat_self[i][o] |= cpmat_grad[o][i]
# TODO in case DisconnectedType is implemented for R_op,
# self._rop_op_stypes_l self._rop_op should considered for
# connection_pattern
return list(map(list, cpmat_self))
ret = io_connection_pattern(self.inner_inputs, self.inner_outputs)
self._connection_pattern = ret
return ret
def infer_shape(self, fgraph, node, shapes):
# TODO: Use `fgraph.shape_feature` to do this instead.
......
......@@ -244,6 +244,8 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
[x, w, b],
[x * w + b],
grad_overrides=[go1, NullType()(), DisconnectedType()()],
# This is a fake override, so a fake connection_pattern must be provided as well
connection_pattern=[[True], [True], [False]],
)
zz2 = pt_sum(op_linear2(xx, ww, bb))
dx2, dw2, db2 = grad(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论