提交 971da4ee authored 作者: Razvan Pascanu's avatar Razvan Pascanu

removed useless class

Before this class was used to validate that computing the shape of an argument does not rely on any input of the inner function of scan. We do not need to do this anymore, since now the output storage size is given as input to scan, and output shapes are the same as input shape (so the internal graph of scan does not come into play anymore).
上级 3accfc38
......@@ -316,77 +316,6 @@ def infer_shape(outs, inputs, input_shapes):
return ret
class Validator(object):
def __init__(self, valid=[], invalid=[], valid_equivalent={}):
'''
Check if variables can be expressed without using variables in invalid.
init_valid_equivalent provides a dictionary mapping some invalid
variables to valid ones that can be used instead.
'''
# Nodes that are valid to have in the graph computing outputs
self.valid = set(valid)
# Nodes that are NOT valid to have in the graph computing outputs
self.invalid = set(invalid)
# Mapping from invalid variables to equivalent valid ones.
self.valid_equivalent = valid_equivalent.copy()
self.valid.update(valid_equivalent.values())
self.invalid.update(valid_equivalent.keys())
def check(self, out):
'''
Go backwards in the graph, from out, and check if out is valid.
If out is a valid node, (out, True) is returned.
If out is not valid, but has an equivalent e, (e, False) is returned.
If out is not valid and has no equivalent, None is returned.
'''
if out in self.valid:
return out, True
elif out in self.valid_equivalent:
return self.valid_equivalent[out], False
elif out in self.invalid:
return None
if out.owner is None:
# This is an unknown input node, so it is invalid.
self.invalid.add(out)
if isinstance(out, tensor.TensorConstant):
# We can clone it to get a valid constant
cloned_out = out.clone()
self.valid.add(cloned_out)
self.valid_equivalent[out] = cloned_out
return cloned_out, False
return None
# Recurse over inputs
inputs = [self.check(i) for i in out.owner.inputs]
# If some inputs are invalid without equivalent, so is out
if None in inputs:
self.invalid.add(out)
return None
# If some inputs are invalid with equivalent,
# an equivalent out should be built and returned
all_inputs = [inp for (inp, is_valid) in inputs]
equiv_inputs = [inp for (inp, is_valid) in inputs if not is_valid]
if equiv_inputs:
cloned_node = out.owner.clone_with_new_inputs(all_inputs)
cloned_out = cloned_node.outputs[out.index]
self.invalid.add(out)
self.valid.add(cloned_out)
self.valid_equivalent[out] = cloned_out
return cloned_out, False
# All inputs are valid, so is out
return out, True
def allocate_memory(T, y_info, y):
"""
Allocates memory for an output of scan.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论