提交 603d1792，作者：Pascal Lamblin

Overhaul of Scan.infer_shape, so it never uses variables from the inner function

上级 6e288d8a
......@@ -16,6 +16,7 @@ import copy
import itertools
import logging
import numpy
import sys
from theano.compile import SharedVariable, function, Param
from theano import compile
......@@ -574,42 +575,83 @@ class Scan(Op):
### Infer Shape
def infer_shape(self, node, input_shapes):
    '''
    Return the shapes of node.outputs, given the shapes of node.inputs.

    The returned shapes are built so that they never use variables
    from the inner function (self.inputs): each inferred dimension is
    checked by a scan_utils.Validator and, when it cannot be expressed
    from outer variables and constants, replaced by a fallback.

    :param node: the Apply node of this Scan op.
    :param input_shapes: shapes of node.inputs, in the same order.
    :return: a list with one entry per output: first the shapes of the
        mit_mot / mit_sot / sit_sot outputs (taken from input_shapes),
        then one entry per nit_sot output (None if it has no shape),
        then the shapes of the shared outputs.
    '''
    # input_shapes correspond to the shapes of node.inputs
    # Here, we build a list inner_ins_shapes, such that
    # inner_ins_shapes[i] is the shape of self.inputs[i]

    # sequences: drop the leading dimension
    seqs_shape = [x[1:] for x in input_shapes[1:1 + self.n_seqs]]

    # mit_mot, mit_sot, sit_sot: one inner input per tap
    n_outs = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
    outs_shape = []
    for idx in xrange(n_outs):
        for k in self.tap_array[idx]:
            outs_shape.append(input_shapes[idx + self.n_seqs + 1][1:])

    # shared_outs
    offset = 1 + self.n_seqs + n_outs
    for idx in xrange(self.n_shared_outs):
        outs_shape.append(input_shapes[idx + offset])

    # non_sequences
    offset += self.n_nit_sot + self.n_other_ignore + self.n_shared_outs
    inner_ins_shapes = seqs_shape + outs_shape + input_shapes[offset:]
    assert len(inner_ins_shapes) == len(self.inputs)

    # Non-sequences have a direct equivalent from self.inputs in
    # node.inputs. The mapping must go from an inner variable to the
    # outer *variable* (not to its shape), because the Validator uses
    # the values of this mapping as replacement inputs when it clones
    # nodes.
    inner_non_sequences = self.inputs[len(seqs_shape) + len(outs_shape):]
    out_equivalent = {}
    for in_ns, out_ns in zip(inner_non_sequences, node.inputs[offset:]):
        out_equivalent[in_ns] = out_ns

    outs_shape = scan_utils.infer_shape(
            outs=self.outputs,
            inputs=self.inputs,
            input_shapes=inner_ins_shapes)
    # Will be used to check if outs_shape can be expressed without using
    # variables in self.inputs
    validator = scan_utils.Validator(
            valid=[],
            invalid=self.inputs,
            valid_equivalent=out_equivalent)

    offset = 1 + self.n_seqs
    scan_outs = list(input_shapes[offset:offset + n_outs])
    offset += n_outs
    for x in xrange(self.n_nit_sot):
        out_shape_x = outs_shape[n_outs + x]
        if out_shape_x is None:
            # This output is not a tensor, and has no shape
            scan_outs.append(None)
        else:
            # We need to make sure that we can compute the shapes from
            # node.inputs, and constants, without using the variables
            # in the inner function.
            r = node.outputs[n_outs + x]
            assert r.ndim == 1 + len(out_shape_x)
            # The first dimension is given by the matching outer input
            shp = [node.inputs[offset + self.n_shared_outs + x]]
            for i, shp_i in zip(xrange(1, r.ndim), out_shape_x):
                # Validate shp_i. v_shp_i is either None (if invalid),
                # or a (variable, Boolean) tuple. The Boolean indicates
                # whether variable is shp_i (if True), or a valid
                # equivalent (if False). Here, we only need the variable.
                v_shp_i = validator.check(shp_i)
                if v_shp_i is None:
                    # shp_i cannot be expressed from outer variables
                    if hasattr(r, 'broadcastable') and r.broadcastable[i]:
                        shp.append(1)
                    else:
                        shp.append(Shape_i(i)(r))
                else:
                    # It can (or at least, an equivalent variable can)
                    shp.append(v_shp_i[0])
            scan_outs.append(tuple(shp))

    scan_outs += list(input_shapes[offset:offset + self.n_shared_outs])
    return scan_outs
### GRAD FUNCTION
def grad(self, args, g_outs):
# 1. forward pass - get the outputs after applying scan
......
......@@ -575,45 +575,124 @@ def equal_computations(x,y, strict=False):
return False
return True
def infer_shape(outs, inputs, input_shapes):
    '''
    Compute the shape of the outputs given the shape of the inputs
    of a theano graph.

    :param outs: the output variables whose shapes are wanted.
    :param inputs: the input variables of the graph.
    :param input_shapes: the shapes of `inputs`, in the same order.
    :return: a list holding, for each variable in `outs`, the entry
        stored for it in the shape feature's `shape_of` dictionary.
    '''
    # We use a ShapeFeature because it has all the necessary logic inside.
    # We don't use the Feature interface, so we need to initialize some
    # things by hand.
    shape_feature = tensor.opt.ShapeFeature()

    # Variable -> tuple(scalars) or None  (All tensor vars map to tuple)
    # All keys of shape_of should be either in valid or in invalid
    shape_feature.shape_of = {}

    # To avoid merging lots of ones together.
    shape_feature.lscalar_one = tensor.constant(1, dtype='int64')

    # Initialize shape_of with the input shapes
    for inp, inp_shp in zip(inputs, input_shapes):
        shape_feature.set_shape(inp, inp_shp)

    def local_traverse(out):
        '''
        Go back in the graph, from out, adding computable shapes to shape_of.
        '''
        if out in shape_feature.shape_of:
            # Its shape is already known
            return
        elif out.owner is None:
            # This is an input of the graph
            shape_feature.init_r(out)
        else:
            # Recurse over inputs
            for inp in out.owner.inputs:
                if inp not in shape_feature.shape_of:
                    local_traverse(inp)

            # shape_feature.on_import does not actually use an env
            # It will call infer_shape and set_shape appropriately
            dummy_env = None
            shape_feature.on_import(dummy_env, out.owner)

    ret = []
    for o in outs:
        local_traverse(o)
        ret.append(shape_feature.shape_of[o])
    return ret
class Validator(object):
    '''
    Check if variables can be expressed without using forbidden variables.

    The validator keeps three collections that grow as `check` is
    called: the variables known to be valid, the variables known to be
    invalid, and a mapping from invalid variables to valid equivalents.
    '''

    def __init__(self, valid=None, invalid=None, valid_equivalent=None):
        '''
        Check if variables can be expressed without using variables in
        invalid.

        :param valid: iterable of variables that are valid to have in
            the graph computing outputs (defaults to empty).
        :param invalid: iterable of variables that are NOT valid to have
            in the graph computing outputs (defaults to empty).
        :param valid_equivalent: dictionary mapping some invalid
            variables to valid ones that can be used instead (defaults
            to empty). It is copied, not modified.
        '''
        # None sentinels avoid sharing mutable default arguments
        if valid is None:
            valid = []
        if invalid is None:
            invalid = []
        if valid_equivalent is None:
            valid_equivalent = {}

        # Nodes that are valid to have in the graph computing outputs
        self.valid = set(valid)

        # Nodes that are NOT valid to have in the graph computing outputs
        self.invalid = set(invalid)

        # Mapping from invalid variables to equivalent valid ones.
        self.valid_equivalent = valid_equivalent.copy()
        self.valid.update(valid_equivalent.values())
        self.invalid.update(valid_equivalent.keys())

    def check(self, out):
        '''
        Go backwards in the graph, from out, and check if out is valid.

        If out is a valid node, (out, True) is returned.
        If out is not valid, but has an equivalent e, (e, False) is returned.
        If out is not valid and has no equivalent, None is returned.
        '''
        # Already-classified variables. The equivalent lookup comes
        # before the invalid one because variables with an equivalent
        # are also recorded as invalid.
        if out in self.valid:
            return out, True
        elif out in self.valid_equivalent:
            return self.valid_equivalent[out], False
        elif out in self.invalid:
            return None

        if out.owner is None:
            # This is an unknown input node, so it is invalid.
            self.invalid.add(out)
            if isinstance(out, tensor.TensorConstant):
                # We can clone it to get a valid constant
                cloned_out = out.clone()
                self.valid.add(cloned_out)
                self.valid_equivalent[out] = cloned_out
                return cloned_out, False

            return None

        # Recurse over inputs
        inputs = [self.check(i) for i in out.owner.inputs]

        # If some inputs are invalid without equivalent, so is out
        if any(res is None for res in inputs):
            self.invalid.add(out)
            return None

        # If some inputs are invalid with equivalent,
        # an equivalent out should be built and returned
        if any(not is_valid for (inp, is_valid) in inputs):
            all_inputs = [inp for (inp, is_valid) in inputs]
            cloned_node = out.owner.clone_with_new_inputs(all_inputs)
            cloned_out = cloned_node.outputs[out.index]
            self.invalid.add(out)
            self.valid.add(cloned_out)
            self.valid_equivalent[out] = cloned_out
            return cloned_out, False

        # All inputs are valid, so is out
        return out, True
def scan_can_remove_outs(op, out_idxs):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论