提交 ed0e5b26 authored 作者: ChienliMa's avatar ChienliMa

Shape should be computed by outer_input

上级 427daa0d
......@@ -4,7 +4,7 @@ from theano.compat import izip
from theano.compile.function_module import orig_function
from theano.compile import SharedVariable, rebuild_collect_shared
from theano.gof import ops_with_inner_function
from theano.gof.graph import io_connection_pattern
from theano.gof.graph import io_connection_pattern, clone_get_equiv
class OpFromGraph(gof.Op):
......@@ -142,9 +142,16 @@ class OpFromGraph(gof.Op):
return io_connection_pattern(self.new_inputs, self.new_outputs)
def infer_shape(self, node, shapes):
return theano.scan_module.scan_utils.infer_shape(self.new_outputs,
self.new_inputs,
shapes)
# clone fgraph
equiv = clone_get_equiv(self.new_inputs, self.new_outputs)
in_v = [equiv[var] for var in self.new_inputs]
out_v = [equiv[var] for var in self.new_outputs]
shape = theano.scan_module.scan_utils.infer_shape(out_v, in_v,
shapes)
replacement = dict([(ori, rpl) for ori, rpl in izip(in_v, node.inputs)])
return [theano.clone(shape[0], replace=replacement)]
def grad(self, inputs, output_grads):
# OpFromGraph doesn't implement a connection_pattern, so for
......
......@@ -160,6 +160,3 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
[op_graph(x)],
[numpy.ones([3,4], dtype=config.floatX)],
OpFromGraph)
if __name__ == '__main__':
unittest.main()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论