提交 fe5cee6e authored 作者: ChienliMa's avatar ChienliMa

Draft of OpFromGraph.connection_pattern and testcase. Test pass.

上级 25c208fd
import numpy
import theano import theano
from theano import gof from theano import gof
from theano.compat import izip from theano.compat import izip
from theano.compile.function_module import orig_function from theano.compile.function_module import orig_function
from theano.compile import SharedVariable, rebuild_collect_shared from theano.compile import SharedVariable, rebuild_collect_shared
from theano.gof import ops_with_inner_function from theano.gof import ops_with_inner_function
from theano.gof import FunctionGraph
class OpFromGraph(gof.Op): class OpFromGraph(gof.Op):
...@@ -134,6 +136,59 @@ class OpFromGraph(gof.Op): ...@@ -134,6 +136,59 @@ class OpFromGraph(gof.Op):
# we wont need this copy anymore # we wont need this copy anymore
output[0] = variable.copy() output[0] = variable.copy()
def connection_pattern(self, node):
"""
Connection_pattern is hard to calculate. In the function, we calculate
the transpose of connection_pattern, where M[output_index,input_index]
indicates whether input with index i affects output with index i.
At last we return the transpose of final result
"""
# or ori_inputs because user do not customize sharejvariable
fgraph = FunctionGraph(self.new_inputs, self.new_outputs)
# c for connection, stores the connection pattern of each variable
c_map = {}
num_of_input = len(fgraph.inputs)
# Initialize input connection pattern, each input affects itself
for index in range(num_of_input):
vec = [False] * num_of_input
vec[index] = True
# Make use of numpy.array to simplify codes
c_map.setdefault(fgraph.inputs[index], numpy.array(vec))
# Toposort the fgraph and get connection pattern for each variable
for node in fgraph.toposort():
# connection pattern of node's inputs.
in_vecs = []
for var in node.inputs:
if not isinstance(var, theano.Constant):
in_vecs.append(c_map[var])
if not hasattr(node.op, 'connection_pattern'):
# By default, nodes inputs affect all outputs
result = in_vecs[0]
for vec in in_vecs[1:]:
result |= vec
results = result * len(node.outputs)
else:
# If node's output connect to node's input, and that input
# connect to fgraph.input, that output connect to fgraph.input
# Therefore we use OR operation here.
results = []
out_vecs = numpy.array(node.op.connection_pattern(node))
for out_vec in out_vecs.T:
result = [False] * num_of_input
for in_vec, val in zip(in_vecs, out_vec):
result |= (in_vec & val)
results.append(result)
for var, result in zip(node.outputs, results):
c_map.setdefault(var, result)
# Transpose final result and convert pattern into python list
pattern = numpy.array([c_map[var] for var in fgraph.outputs]).T
return [list(vec) for vec in pattern]
def grad(self, inputs, output_grads): def grad(self, inputs, output_grads):
# OpFromGraph doesn't implement a connection_pattern, so for # OpFromGraph doesn't implement a connection_pattern, so for
# now we regard all inputs and outputs as connected. This will # now we regard all inputs and outputs as connected. This will
......
...@@ -107,6 +107,19 @@ class T_OpFromGraph(unittest.TestCase): ...@@ -107,6 +107,19 @@ class T_OpFromGraph(unittest.TestCase):
fn = function([x, y, z], f) fn = function([x, y, z], f)
assert numpy.allclose(15.0 + s.get_value(), assert numpy.allclose(15.0 + s.get_value(),
fn(xv, yv, zv)) fn(xv, yv, zv))
def test_connection_pattern(self):
import numpy
x, y, z = T.matrices('xyz')
out1 = x * y
out2 = y * z
op = OpFromGraph([x ,y, z], [out1, out2], moe='FAST_RUN')
results = op.connection_pattern(None)
expect_result = [[True, False],
[True, True],
[False, True]]
assert results == expect_result
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论