提交 64b126df authored 作者: Olivier Breuleux's avatar Olivier Breuleux

new klass!

上级 f7151839
...@@ -258,9 +258,9 @@ class Function(object): ...@@ -258,9 +258,9 @@ class Function(object):
# Check if inputs are missing or if inputs were set more than once # Check if inputs are missing or if inputs were set more than once
for c in self.input_storage: for c in self.input_storage:
if c.required and not c.provided: if c.required and not c.provided:
raise TypeError("Missing required input: %s" % self.inv_finder[c].result) raise TypeError("Missing required input: %s" % getattr(self.inv_finder[c], 'result', self.inv_finder[c]))
if c.provided > 1: if c.provided > 1:
raise TypeError("Multiple values for input: %s" % self.inv_finder[c].result) raise TypeError("Multiple values for input: %s" % getattr(self.inv_finder[c], 'result', self.inv_finder[c]))
# Do the actual work # Do the actual work
self.fn() self.fn()
outputs = [x.data for x in self.output_storage] outputs = [x.data for x in self.output_storage]
...@@ -351,6 +351,7 @@ class SanityCheckFunction(Function): ...@@ -351,6 +351,7 @@ class SanityCheckFunction(Function):
### FunctionMaker ### FunctionMaker
### ###
NODEFAULT = ['NODEFAULT']
class FunctionMaker(object): class FunctionMaker(object):
@staticmethod @staticmethod
...@@ -404,6 +405,7 @@ class FunctionMaker(object): ...@@ -404,6 +405,7 @@ class FunctionMaker(object):
in the graph from the inputs to the outputs in the graph from the inputs to the outputs
""" """
# Handle the case where inputs and/or outputs is a single Result (not in a list) # Handle the case where inputs and/or outputs is a single Result (not in a list)
unpack_single = False unpack_single = False
if not isinstance(outputs, (list, tuple)): if not isinstance(outputs, (list, tuple)):
...@@ -414,7 +416,7 @@ class FunctionMaker(object): ...@@ -414,7 +416,7 @@ class FunctionMaker(object):
# Wrap them in In or Out instances if needed. # Wrap them in In or Out instances if needed.
inputs, outputs = map(self.wrap_in, inputs), map(self.wrap_out, outputs) inputs, outputs = map(self.wrap_in, inputs), map(self.wrap_out, outputs)
_inputs = gof.graph.inputs([o.result for o in outputs]) _inputs = gof.graph.inputs([o.result for o in outputs] + [i.update for i in inputs if getattr(i, 'update', False)])
indices = [[input] + self.expand_in(input, _inputs) for input in inputs] indices = [[input] + self.expand_in(input, _inputs) for input in inputs]
expanded_inputs = reduce(list.__add__, [list(z) for x, y, z in indices], []) expanded_inputs = reduce(list.__add__, [list(z) for x, y, z in indices], [])
...@@ -482,7 +484,17 @@ class FunctionMaker(object): ...@@ -482,7 +484,17 @@ class FunctionMaker(object):
# one storage unit. The indices and subinputs lists represent which # one storage unit. The indices and subinputs lists represent which
# of the kit's inputs are active in this graph, so we make as many # of the kit's inputs are active in this graph, so we make as many
# storage units as needed # storage units as needed
input_storage += [[None] for i in indices] if isinstance(default, (list, tuple)) \
and all(isinstance(x, gof.Container) for x in default):
if len(default) == len(indices):
input_storage += [x.storage for x in default]
elif len(default) > len(indices):
input_storage += [default[i].storage for i in indices]
else:
raise ValueError('Not enough storage for SymbolicInputKit', input, indices, default)
default = NODEFAULT
else:
input_storage += [[None] for i in indices]
else: else:
# Normal case: one new, independent storage unit # Normal case: one new, independent storage unit
input_storage.append([None]) input_storage.append([None])
...@@ -496,6 +508,8 @@ class FunctionMaker(object): ...@@ -496,6 +508,8 @@ class FunctionMaker(object):
# Even though a SymbolicInputKit represents more than one input, # Even though a SymbolicInputKit represents more than one input,
# we still only have one entry for the defaults list. # we still only have one entry for the defaults list.
if isinstance(input, SymbolicInputKit): if isinstance(input, SymbolicInputKit):
if default is NODEFAULT:
_defaults.append((False, False, None))
if default is None: if default is None:
_defaults.append((True, True, None)) _defaults.append((True, True, None))
else: else:
......
...@@ -99,6 +99,8 @@ class SymbolicInputKit(object): ...@@ -99,6 +99,8 @@ class SymbolicInputKit(object):
except ValueError: except ValueError:
pass pass
ret.sort() ret.sort()
if not ret:
return [[], []]
return zip(*ret) return zip(*ret)
......
差异被折叠。
...@@ -366,11 +366,11 @@ def tensor(*args, **kwargs): ...@@ -366,11 +366,11 @@ def tensor(*args, **kwargs):
def _multi(*fns): def _multi(*fns):
def f2(f, *names): def f2(f, *names):
if isinstance(names, int): if names and isinstance(names[0], int):
if names == 1: if names == 1:
return f() return f()
else: else:
return [f() for i in xrange(names)] return [f() for i in xrange(names[0])]
if isinstance(names, tuple): if isinstance(names, tuple):
if len(names) == 1: if len(names) == 1:
names = names[0] names = names[0]
...@@ -1537,7 +1537,7 @@ def get_vector_length(v): ...@@ -1537,7 +1537,7 @@ def get_vector_length(v):
raise TypeError('argument must be symbolic vector') raise TypeError('argument must be symbolic vector')
if isinstance(v, gof.Constant) and v.type.ndim == 1: if isinstance(v, gof.Constant) and v.type.ndim == 1:
return len(v.data) return len(v.data)
if v.owner and isinstance(v.owner.op, join): if v.owner and isinstance(v.owner.op, Join):
try: try:
return join.vec_length(v) return join.vec_length(v)
except: except:
......
...@@ -32,7 +32,10 @@ class RandomFunction(gof.Op): ...@@ -32,7 +32,10 @@ class RandomFunction(gof.Op):
out -> the random numbers we generated out -> the random numbers we generated
""" """
args = map(tensor.as_tensor, args) args = map(tensor.as_tensor, args)
shape = tensor.as_tensor(shape) if shape == () or shape == []:
shape = tensor.lvector()
else:
shape = tensor.as_tensor(shape)
assert shape.type == tensor.lvector assert shape.type == tensor.lvector
assert len(args) <= len(self.args) assert len(args) <= len(self.args)
args += (None,) * (len(self.args) - len(args)) args += (None,) * (len(self.args) - len(args))
...@@ -96,7 +99,10 @@ def random_function(fn, dtype, *rfargs, **rfkwargs): ...@@ -96,7 +99,10 @@ def random_function(fn, dtype, *rfargs, **rfkwargs):
r, shape, args = args[0], args[1], args[2:] r, shape, args = args[0], args[1], args[2:]
else: else:
r, shape, args = ndim, args[0], args[1:] r, shape, args = ndim, args[0], args[1:]
shape = tensor.as_tensor(shape) if shape == () or shape == []:
shape = tensor.TensorConstant(type = tensor.lvector, data = shape)
else:
shape = tensor.as_tensor(shape)
ndim = tensor.get_vector_length(shape) ndim = tensor.get_vector_length(shape)
if ndim is None: if ndim is None:
raise ValueError('Cannot infer the number of dimensions from the shape argument.') raise ValueError('Cannot infer the number of dimensions from the shape argument.')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论