提交 e1695582 authored 作者: ChienliMa's avatar ChienliMa

Another implementation without error

上级 a51c6ec4
...@@ -146,15 +146,35 @@ class OpFromGraph(gof.Op): ...@@ -146,15 +146,35 @@ class OpFromGraph(gof.Op):
return io_connection_pattern(self.new_inputs, self.new_outputs) return io_connection_pattern(self.new_inputs, self.new_outputs)
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
# Construct a new fgraph fg = FunctionGraph(self.new_inputs, self.new_outputs)
equiv = graph.clone_get_equiv(self.new_inputs, self.new_outputs) order = fg.toposort()
replacement = dict((equiv[var], var) for var in self.new_inputs)
out_cpy = theano.clone([equiv[var] for var in self.new_outputs], replace=replacement) # A dict that map variable to its shape
fg = FunctionGraph(self.new_inputs, out_cpy) shape_map = {}
in_shapes = dict([(v, shp) for v, shp in zip(fg.inputs, shapes)]) # set the input shapes of the fgraph
all_shapes = theano.tensor.utils.shape_of_variables(fg, in_shapes) for in_var, shape in zip(fg.inputs, shapes):
return [all_shapes[var] for var in fg.outputs] shape_map.setdefault(in_var, shape)
# calculate the output shape from input shape
for node in order:
# deal with constant
for var in node.inputs:
if isinstance(var, theano.Constant):
shape_map.setdefault(var, var.shape)
# assert we already have the shape of necessary inputs
assert all([var in shape_map.keys() for var in node.inputs])
# calculate output shape
in_shapes = [ shape_map[var] for var in node.inputs]
out_shapes = node.op.infer_shape(node, in_shapes)
# store the shape of that variable
for out_var, shape in zip(node.outputs, out_shapes):
shape_map.setdefault(out_var, shape)
# extract output shape
return tuple([shape_map[var] for var in fg.outputs])
def grad(self, inputs, output_grads): def grad(self, inputs, output_grads):
# OpFromGraph doesn't implement a connection_pattern, so for # OpFromGraph doesn't implement a connection_pattern, so for
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论