提交 7227f667 authored 作者: ChienliMa's avatar ChienliMa

add missing changes

上级 3673950e
......@@ -145,35 +145,9 @@ class OpFromGraph(gof.Op):
return io_connection_pattern(self.new_inputs, self.new_outputs)
def infer_shape(self, node, shapes):
fg = FunctionGraph(self.new_inputs, self.new_outputs)
order = fg.toposort()
# A dict that map variable to its shape
shape_map = {}
# set the input shapes of the fgraph
for in_var, shape in zip(fg.inputs, shapes):
shape_map.setdefault(in_var, shape)
# calculate the output shape from input shape
for node in order:
# deal with constant
for var in node.inputs:
if isinstance(var, theano.Constant):
shape_map.setdefault(var, var.shape)
# assert we already have the shape of necessary inputs
assert all([var in shape_map.keys() for var in node.inputs])
# calculate output shape
in_shapes = [shape_map[var] for var in node.inputs]
out_shapes = node.op.infer_shape(node, in_shapes)
# store the shape of that variable
for out_var, shape in zip(node.outputs, out_shapes):
shape_map.setdefault(out_var, shape)
# extract output shape
return tuple([shape_map[var] for var in fg.outputs])
return theano.scan_module.scan_utils.infer_shape(self.new_outputs,
self.new_inputs,
shapes)
def grad(self, inputs, output_grads):
# OpFromGraph doesn't implement a connection_pattern, so for
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论