提交 fe5cee6e authored 作者: ChienliMa's avatar ChienliMa

Draft of OpFromGraph.connection_pattern and testcase. Test pass.

上级 25c208fd
import numpy
import theano
from theano import gof
from theano.compat import izip
from theano.compile.function_module import orig_function
from theano.compile import SharedVariable, rebuild_collect_shared
from theano.gof import ops_with_inner_function
from theano.gof import FunctionGraph
class OpFromGraph(gof.Op):
......@@ -134,6 +136,59 @@ class OpFromGraph(gof.Op):
# we wont need this copy anymore
output[0] = variable.copy()
def connection_pattern(self, node):
"""
Connection_pattern is hard to calculate. In the function, we calculate
the transpose of connection_pattern, where M[output_index,input_index]
indicates whether input with index i affects output with index i.
At last we return the transpose of final result
"""
# or ori_inputs because user do not customize sharejvariable
fgraph = FunctionGraph(self.new_inputs, self.new_outputs)
# c for connection, stores the connection pattern of each variable
c_map = {}
num_of_input = len(fgraph.inputs)
# Initialize input connection pattern, each input affects itself
for index in range(num_of_input):
vec = [False] * num_of_input
vec[index] = True
# Make use of numpy.array to simplify codes
c_map.setdefault(fgraph.inputs[index], numpy.array(vec))
# Toposort the fgraph and get connection pattern for each variable
for node in fgraph.toposort():
# connection pattern of node's inputs.
in_vecs = []
for var in node.inputs:
if not isinstance(var, theano.Constant):
in_vecs.append(c_map[var])
if not hasattr(node.op, 'connection_pattern'):
# By default, nodes inputs affect all outputs
result = in_vecs[0]
for vec in in_vecs[1:]:
result |= vec
results = result * len(node.outputs)
else:
# If node's output connect to node's input, and that input
# connect to fgraph.input, that output connect to fgraph.input
# Therefore we use OR operation here.
results = []
out_vecs = numpy.array(node.op.connection_pattern(node))
for out_vec in out_vecs.T:
result = [False] * num_of_input
for in_vec, val in zip(in_vecs, out_vec):
result |= (in_vec & val)
results.append(result)
for var, result in zip(node.outputs, results):
c_map.setdefault(var, result)
# Transpose final result and convert pattern into python list
pattern = numpy.array([c_map[var] for var in fgraph.outputs]).T
return [list(vec) for vec in pattern]
def grad(self, inputs, output_grads):
# OpFromGraph doesn't implement a connection_pattern, so for
# now we regard all inputs and outputs as connected. This will
......
......@@ -108,6 +108,19 @@ class T_OpFromGraph(unittest.TestCase):
assert numpy.allclose(15.0 + s.get_value(),
fn(xv, yv, zv))
def test_connection_pattern(self):
import numpy
x, y, z = T.matrices('xyz')
out1 = x * y
out2 = y * z
op = OpFromGraph([x ,y, z], [out1, out2], moe='FAST_RUN')
results = op.connection_pattern(None)
expect_result = [[True, False],
[True, True],
[False, True]]
assert results == expect_result
if __name__ == '__main__':
unittest.main()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论