提交 9a61e170 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #4619 from nke001/ccw3398_cpy

ccw3398
......@@ -25,7 +25,7 @@ from theano.gof import (graph, utils, link, ops_with_inner_function)
from theano.gof.link import raise_with_op
from theano.compile.function_module import (
FunctionMaker, Function, infer_reuse_pattern,
SymbolicInputKit, SymbolicOutput, Supervisor, std_fgraph)
SymbolicOutput, Supervisor, std_fgraph)
from theano.compile.mode import Mode, register_mode
from theano.compile.ops import OutputGuard
......@@ -2517,27 +2517,9 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
# default.storage to input_storage.
if indices is not None:
raise TypeError("Cannot take a Container instance as "
"default for a SymbolicInputKit.")
"default for a SymbolicInput.")
input_storage.append(default.storage)
default = None
elif isinstance(input, SymbolicInputKit):
# If the input is a SymbolicInputKit, it represents more than
# one storage unit. The indices and subinputs lists represent
# which of the kit's inputs are active in this graph, so we
# make as many storage units as needed
if isinstance(default, (list, tuple)) \
and all(isinstance(x, gof.Container) for x in default):
if len(default) == len(indices):
input_storage += [x.storage for x in default]
elif len(default) > len(indices):
input_storage += [default[i].storage for i in indices]
else:
raise ValueError(
'Not enough storage for SymbolicInputKit',
input, indices, default)
default = _NODEFAULT
else:
input_storage += [[None] for i in indices]
else:
# Normal case: one new, independent storage unit
input_storage.append([None])
......@@ -2550,16 +2532,7 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
# storage after each function call
# - value is the value that will be put in the storage initially
# Even though a SymbolicInputKit represents more than one input,
# we still only have one entry for the defaults list.
if isinstance(input, SymbolicInputKit):
if default is _NODEFAULT:
_defaults.append((False, False, None))
elif default is None:
_defaults.append((True, True, None))
else:
_defaults.append((False, False, default))
elif input.update is not None:
if input.update is not None:
# If the input has an update, then (logically) it is
# not required since it is just a parameter and of
# course we don't want to refeed the default back into
......
......@@ -16,12 +16,11 @@ import numpy
import theano
from theano import config, gof
from functools import partial
from theano.compat import izip
from theano.gof import graph
import theano.compile.mode
from theano.compile.io import (
In, SymbolicInput, SymbolicInputKit, SymbolicOutput)
In, SymbolicInput, SymbolicOutput)
from theano.compile.ops import deep_copy_op, view_op
from theano.gof.graph import is_same_graph
from theano.gof.op import ops_with_inner_function
......@@ -286,7 +285,7 @@ class Function(object):
indices = None
"""
List of (SymbolicInput|SymbolicInputKit, indices, [SymbolicInput,...]),
List of (SymbolicInput, indices, [SymbolicInput,...]),
one tuple for each input.
The first tuple element is the SymbolicInput object for the corresponding
......@@ -396,7 +395,6 @@ class Function(object):
# self.input_storage inplace.
for i, ((input, indices, sinputs), (required, refeed, value)) in \
enumerate(zip(self.indices, defaults)):
# this is true iff input is not a SymbolicInputKit
if indices is None:
# containers is being used as a stack. Here we pop off
# the next one.
......@@ -432,41 +430,6 @@ class Function(object):
named_inputs.append(input.name)
inv_finder[c] = input
containers[:1] = []
else:
# TODO The following code may need to do something to handle
# implicit inputs.
# The input is a SymbolicInputKit, so we take as many
# containers as the Kit provides inputs
cs = containers[:len(indices)]
# distribute does the initialization of the containers
input.distribute(value, indices, cs)
f = partial(distribute, indices, cs)
# Like before, we set a finder entry for the kit. Note that
# we are not mapping to a container but to a function which
# can reinitialize all the containers
finder[i] = f
finder[input] = f
if input.name not in finder:
finder[input.name] = f
else:
finder[input.name] = DUPLICATE
# For each input in the kit and its corresponding
# container, we put an entry in finder. This allows
# the user to micro-manage elements of the kit if need
# be. All containers inherit the required field and
# have their own "provided" counter
for c, sin in zip(cs, sinputs):
finder[sin.variable] = c
finder[sin.name] = c
if sin.name not in finder:
finder[sin.name] = c
else:
finder[sin.name] = DUPLICATE
inv_finder[c] = input
c.required = required
c.provided = 0
containers[:len(indices)] = []
self.finder = finder
self.inv_finder = inv_finder
......@@ -1033,16 +996,8 @@ def _pickle_Function(f):
for (input, indices, inputs), (required, refeed, default) in \
zip(f.indices, f.defaults):
if isinstance(input, SymbolicInputKit):
li = len(indices)
if not default:
input_storage.append(ins[:li])
else:
input_storage.append(default)
ins[:li] = []
else:
input_storage.append(ins[0])
del ins[0]
input_storage.append(ins[0])
del ins[0]
inputs_data = [x.data for x in f.input_storage]
......@@ -1210,7 +1165,7 @@ class FunctionMaker(object):
@staticmethod
def wrap_in(input):
if isinstance(input, (SymbolicInput, SymbolicInputKit)):
if isinstance(input, (SymbolicInput)):
return input
elif isinstance(input, gof.Variable):
# r -> SymbolicInput(variable=r)
......@@ -1234,9 +1189,10 @@ class FunctionMaker(object):
# instances in inputs. For SymbolicInput, this returns None
# as the list of indices and a list with just the
# SymbolicInput.
if isinstance(sinput, SymbolicInputKit):
return sinput.complete(rinputs)
elif isinstance(sinput, SymbolicInput):
# if isinstance(sinput, SymbolicInputKit):
# return sinput.complete(rinputs)
# elif isinstance(sinput, SymbolicInput):
if isinstance(sinput, SymbolicInput):
return [None, [sinput]]
@staticmethod
......@@ -1858,7 +1814,7 @@ def convert_function_input(input):
`In`(r, name=name, value=val, update=up, autoname=True)
"""
if isinstance(input, (SymbolicInput, SymbolicInputKit)):
if isinstance(input, SymbolicInput):
return input
elif isinstance(input, gof.Constant):
raise TypeError('A Constant instance is not a legal function input',
......@@ -1887,7 +1843,7 @@ def convert_function_input(input):
else:
raise TypeError("Invalid input syntax: %s (check "
"documentation or use an In instance)" % orig)
elif isinstance(input[0], (SymbolicInput, SymbolicInputKit)):
elif isinstance(input[0], SymbolicInput):
if len(input) == 1:
return input[0]
elif len(input) == 2:
......
......@@ -97,69 +97,6 @@ class SymbolicInput(object):
return str(self)
# TODO: FB: I think this isn't used, confirm this and remove.
class SymbolicInputKit(object):
"""
Represents a group ("kit") of SymbolicInputs. If fed into function or
FunctionMaker, only the inputs which are needed to compile the function
properly will be taken.
A SymbolicInputKit provides the distribute function in order to set or
initialize several inputs from a single value. Specialized Kits should
override it.
"""
def __init__(self, name):
if not isinstance(name, string_types):
raise TypeError('name must be a string (got: %s)' % name)
self.name = name
self.sinputs = []
self.variables = []
def add_input(self, sinput):
"""
Add a SymbolicInput to this SymbolicInputKit.
It will be given the next available index.
"""
self.sinputs.append(sinput)
self.variables.append(sinput.variable)
def distribute(self, value, indices, containers):
"""
Given a list of indices corresponding to SymbolicInputs in this kit
as well as a corresponding list of containers, initialize all the
containers using the provided value.
"""
raise NotImplementedError
def complete(self, inputs):
"""
Given inputs (a list of Variable instances), checks through all the
SymbolicInputs in the kit and return a sorted list of indices and a list
of their corresponding SymbolicInputs such that each of them represents
some variable in the inputs list.
Not all the provided inputs will have a corresponding SymbolicInput in
the kit.
"""
ret = []
for input in inputs:
try:
i = self.variables.index(input)
ret.append((i, self.sinputs[i]))
except ValueError:
pass
ret.sort()
if not ret:
return [[], []]
return list(zip(*ret))
class In(SymbolicInput):
"""
Represents a symbolic input for use with function or FunctionMaker.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论