提交 603d1792 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Overhaul of Scan.infer_shape, so it never uses variables from the inner function

上级 6e288d8a
...@@ -16,6 +16,7 @@ import copy ...@@ -16,6 +16,7 @@ import copy
import itertools import itertools
import logging import logging
import numpy import numpy
import sys
from theano.compile import SharedVariable, function, Param from theano.compile import SharedVariable, function, Param
from theano import compile from theano import compile
...@@ -574,42 +575,83 @@ class Scan(Op): ...@@ -574,42 +575,83 @@ class Scan(Op):
### Infer Shape ### Infer Shape
def infer_shape(self, node, input_shapes):
    """
    Return the symbolic shapes of the outputs of this Scan node.

    The returned shapes are built so that they do not depend on variables
    of the inner function (``self.inputs``). Every dimension inferred
    through the inner graph is passed to a ``scan_utils.Validator``; when
    it cannot be expressed without inner variables, it is replaced by
    ``1`` (if the output is broadcastable along that dimension) or by
    ``Shape_i(i)`` applied to the corresponding output of ``node``.

    :param node: the node applying this op; its ``inputs`` and ``outputs``
        are used to build the returned shapes.
    :param input_shapes: the shapes of ``node.inputs``, in the same order.
    :return: a list of shapes: first the input shapes of the mit_mot,
        mit_sot and sit_sot outputs, then one entry per nit_sot output
        (a tuple, or ``None`` when no shape could be inferred), then the
        input shapes of the shared outputs.
    """
    # input_shapes correspond to the shapes of node.inputs
    # Here, we build a list inner_ins_shapes, such that inner_ins_shapes[i]
    # is the shape of self.inputs[i]

    # sequences
    seqs_shape = [x[1:] for x in input_shapes[1:1 + self.n_seqs]]

    # mit_mot, mit_sot, sit_sot
    n_outs = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
    outs_shape = []
    for idx in xrange(n_outs):
        # One inner input per tap, with the leading dimension dropped
        for k in self.tap_array[idx]:
            outs_shape += [input_shapes[idx + self.n_seqs + 1][1:]]

    # shared_outs
    offset = 1 + self.n_seqs + n_outs
    for idx in xrange(self.n_shared_outs):
        outs_shape += [input_shapes[idx + offset]]

    # non_sequences
    offset += self.n_nit_sot + self.n_other_ignore + self.n_shared_outs
    inner_ins_shapes = seqs_shape + outs_shape + input_shapes[offset:]
    assert len(inner_ins_shapes) == len(self.inputs)

    # Non-sequences have a direct equivalent from self.inputs in node.inputs.
    # The equivalent must be the outer *variable* (node.inputs), not its
    # shape: the Validator substitutes it for the inner variable when it
    # rebuilds a shape expression.
    inner_non_sequences = self.inputs[len(seqs_shape) + len(outs_shape):]
    out_equivalent = {}
    for in_ns, out_ns in zip(inner_non_sequences, node.inputs[offset:]):
        out_equivalent[in_ns] = out_ns

    outs_shape = scan_utils.infer_shape(
        outs=self.outputs,
        inputs=self.inputs,
        input_shapes=inner_ins_shapes)

    # Will be used to check if outs_shape can be expressed without using
    # variables in self.inputs
    validator = scan_utils.Validator(
        valid=[],
        invalid=self.inputs,
        valid_equivalent=out_equivalent)

    offset = 1 + self.n_seqs
    scan_outs = [x for x in input_shapes[offset:offset + n_outs]]
    offset += n_outs
    for x in xrange(self.n_nit_sot):
        out_shape_x = outs_shape[n_outs + x]
        if out_shape_x is None:
            # This output is not a tensor, and has no shape
            scan_outs.append(None)
        else:
            # We need to make sure that we can compute the shapes from
            # node.inputs, and constants, without using the variables
            # in the inner function.
            r = node.outputs[n_outs + x]
            assert r.ndim == 1 + len(out_shape_x)
            # The leading dimension is taken directly from node.inputs
            shp = [node.inputs[offset + self.n_shared_outs + x]]
            for i, shp_i in zip(xrange(1, r.ndim), out_shape_x):
                # Validate shp_i. v_shp_i is either None (if invalid),
                # or a (variable, Boolean) tuple. The Boolean indicates
                # whether variable is shp_i (if True), or a valid
                # equivalent (if False). Here, we only need the variable.
                v_shp_i = validator.check(shp_i)
                if v_shp_i is None:
                    # shp_i cannot be expressed without inner variables
                    if hasattr(r, 'broadcastable') and r.broadcastable[i]:
                        shp.append(1)
                    else:
                        shp.append(Shape_i(i)(r))
                else:
                    # It can (or at least, an equivalent variable can)
                    shp.append(v_shp_i[0])
            scan_outs.append(tuple(shp))

    scan_outs += [x for x in
                  input_shapes[offset:offset + self.n_shared_outs]]
    return scan_outs
### GRAD FUNCTION ### GRAD FUNCTION
def grad(self, args, g_outs): def grad(self, args, g_outs):
# 1. forward pass - get the outputs after applying scan # 1. forward pass - get the outputs after applying scan
......
...@@ -575,45 +575,124 @@ def equal_computations(x,y, strict=False): ...@@ -575,45 +575,124 @@ def equal_computations(x,y, strict=False):
return False return False
return True return True
def infer_shape(outs, inputs, input_shapes):
    '''
    Compute the shapes of `outs`, given that the variables in `inputs`
    have the shapes listed in `input_shapes` (same order).

    Returns a list with one entry per element of `outs`, read from the
    `shape_of` dictionary of a `tensor.opt.ShapeFeature`.
    '''
    # A ShapeFeature already contains the shape-propagation logic. It is
    # used here without going through the Feature interface, so the
    # attributes it relies on are set up by hand.
    shape_feature = tensor.opt.ShapeFeature()

    # Variable -> tuple(scalars) or None (all tensor variables map to a tuple)
    shape_feature.shape_of = {}

    # A single shared constant one, to avoid merging lots of ones together.
    shape_feature.lscalar_one = tensor.constant(1, dtype='int64')

    # Seed shape_of with the shapes provided for the inputs
    for var, var_shape in zip(inputs, input_shapes):
        shape_feature.set_shape(var, var_shape)

    def local_traverse(out):
        '''
        Walk the graph backwards from `out`, registering in shape_of the
        shape of every variable met on the way.
        '''
        if out in shape_feature.shape_of:
            # Nothing to do, the shape is already known
            return

        node = out.owner
        if node is None:
            # `out` is an input of the graph
            shape_feature.init_r(out)
            return

        # Make sure the shapes of all inputs of the node are known first
        for parent in node.inputs:
            if parent not in shape_feature.shape_of:
                local_traverse(parent)

        # on_import is given None as env (the original author notes it
        # does not actually use one); it calls infer_shape and set_shape
        # for the node.
        shape_feature.on_import(None, node)

    shapes = []
    for out in outs:
        local_traverse(out)
        shapes.append(shape_feature.shape_of[out])
    return shapes
class Validator(object):
    '''
    Check whether variables can be expressed without using a given set of
    invalid variables, building equivalent variables where possible.
    '''

    def __init__(self, valid=[], invalid=[], valid_equivalent={}):
        '''
        :param valid: variables that are allowed in the graph computing
            the checked outputs.
        :param invalid: variables that are NOT allowed in that graph.
        :param valid_equivalent: dictionary mapping some invalid variables
            to valid ones that can be used instead. It is copied, so the
            caller's dictionary is not modified.
        '''
        # Mapping from invalid variables to equivalent valid ones.
        self.valid_equivalent = valid_equivalent.copy()
        # Valid variables: the given ones, plus every provided equivalent
        self.valid = set(valid).union(valid_equivalent.values())
        # Invalid variables: the given ones, plus every variable that had
        # to be given an equivalent
        self.invalid = set(invalid).union(valid_equivalent.keys())

    def check(self, out):
        '''
        Go backwards in the graph, from out, and check if out is valid.

        If out is a valid node, (out, True) is returned.
        If out is not valid, but has an equivalent e, (e, False) is returned.
        If out is not valid and has no equivalent, None is returned.
        '''
        # Variables whose status is already known
        if out in self.valid:
            return out, True
        if out in self.valid_equivalent:
            return self.valid_equivalent[out], False
        if out in self.invalid:
            return None

        node = out.owner
        if node is None:
            # This is an unknown input node, so it is invalid.
            self.invalid.add(out)
            if not isinstance(out, tensor.TensorConstant):
                return None
            # A constant can be cloned to get a valid equivalent
            replacement = out.clone()
            self.valid.add(replacement)
            self.valid_equivalent[out] = replacement
            return replacement, False

        # Recurse over the inputs of the node that produced out
        checked = [self.check(parent) for parent in node.inputs]

        # If some inputs are invalid without equivalent, so is out
        if any(result is None for result in checked):
            self.invalid.add(out)
            return None

        # All inputs are valid as they are, so is out
        if all(is_valid for (_, is_valid) in checked):
            return out, True

        # Some inputs were replaced by an equivalent: rebuild the node on
        # the new inputs and register the matching output as equivalent.
        new_inputs = [var for (var, _) in checked]
        new_node = node.clone_with_new_inputs(new_inputs)
        replacement = new_node.outputs[out.index]
        self.invalid.add(out)
        self.valid.add(replacement)
        self.valid_equivalent[out] = replacement
        return replacement, False
def scan_can_remove_outs(op, out_idxs): def scan_can_remove_outs(op, out_idxs):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论