提交 64b126df authored 作者: Olivier Breuleux's avatar Olivier Breuleux

new klass!

上级 f7151839
......@@ -258,9 +258,9 @@ class Function(object):
# Check if inputs are missing or if inputs were set more than once
for c in self.input_storage:
if c.required and not c.provided:
raise TypeError("Missing required input: %s" % self.inv_finder[c].result)
raise TypeError("Missing required input: %s" % getattr(self.inv_finder[c], 'result', self.inv_finder[c]))
if c.provided > 1:
raise TypeError("Multiple values for input: %s" % self.inv_finder[c].result)
raise TypeError("Multiple values for input: %s" % getattr(self.inv_finder[c], 'result', self.inv_finder[c]))
# Do the actual work
self.fn()
outputs = [x.data for x in self.output_storage]
......@@ -351,6 +351,7 @@ class SanityCheckFunction(Function):
### FunctionMaker
###
NODEFAULT = ['NODEFAULT']
class FunctionMaker(object):
@staticmethod
......@@ -404,6 +405,7 @@ class FunctionMaker(object):
in the graph from the inputs to the outputs
"""
# Handle the case where inputs and/or outputs is a single Result (not in a list)
unpack_single = False
if not isinstance(outputs, (list, tuple)):
......@@ -414,7 +416,7 @@ class FunctionMaker(object):
# Wrap them in In or Out instances if needed.
inputs, outputs = map(self.wrap_in, inputs), map(self.wrap_out, outputs)
_inputs = gof.graph.inputs([o.result for o in outputs])
_inputs = gof.graph.inputs([o.result for o in outputs] + [i.update for i in inputs if getattr(i, 'update', False)])
indices = [[input] + self.expand_in(input, _inputs) for input in inputs]
expanded_inputs = reduce(list.__add__, [list(z) for x, y, z in indices], [])
......@@ -482,7 +484,17 @@ class FunctionMaker(object):
# one storage unit. The indices and subinputs lists represent which
# of the kit's inputs are active in this graph, so we make as many
# storage units as needed
input_storage += [[None] for i in indices]
if isinstance(default, (list, tuple)) \
and all(isinstance(x, gof.Container) for x in default):
if len(default) == len(indices):
input_storage += [x.storage for x in default]
elif len(default) > len(indices):
input_storage += [default[i].storage for i in indices]
else:
raise ValueError('Not enough storage for SymbolicInputKit', input, indices, default)
default = NODEFAULT
else:
input_storage += [[None] for i in indices]
else:
# Normal case: one new, independent storage unit
input_storage.append([None])
......@@ -496,6 +508,8 @@ class FunctionMaker(object):
# Even though a SymbolicInputKit represents more than one input,
# we still only have one entry for the defaults list.
if isinstance(input, SymbolicInputKit):
if default is NODEFAULT:
_defaults.append((False, False, None))
if default is None:
_defaults.append((True, True, None))
else:
......
......@@ -99,6 +99,8 @@ class SymbolicInputKit(object):
except ValueError:
pass
ret.sort()
if not ret:
return [[], []]
return zip(*ret)
......
差异被折叠。
......@@ -366,11 +366,11 @@ def tensor(*args, **kwargs):
def _multi(*fns):
def f2(f, *names):
if isinstance(names, int):
if names and isinstance(names[0], int):
if names == 1:
return f()
else:
return [f() for i in xrange(names)]
return [f() for i in xrange(names[0])]
if isinstance(names, tuple):
if len(names) == 1:
names = names[0]
......@@ -1537,7 +1537,7 @@ def get_vector_length(v):
raise TypeError('argument must be symbolic vector')
if isinstance(v, gof.Constant) and v.type.ndim == 1:
return len(v.data)
if v.owner and isinstance(v.owner.op, join):
if v.owner and isinstance(v.owner.op, Join):
try:
return join.vec_length(v)
except:
......
......@@ -32,7 +32,10 @@ class RandomFunction(gof.Op):
out -> the random numbers we generated
"""
args = map(tensor.as_tensor, args)
shape = tensor.as_tensor(shape)
if shape == () or shape == []:
shape = tensor.lvector()
else:
shape = tensor.as_tensor(shape)
assert shape.type == tensor.lvector
assert len(args) <= len(self.args)
args += (None,) * (len(self.args) - len(args))
......@@ -96,7 +99,10 @@ def random_function(fn, dtype, *rfargs, **rfkwargs):
r, shape, args = args[0], args[1], args[2:]
else:
r, shape, args = ndim, args[0], args[1:]
shape = tensor.as_tensor(shape)
if shape == () or shape == []:
shape = tensor.TensorConstant(type = tensor.lvector, data = shape)
else:
shape = tensor.as_tensor(shape)
ndim = tensor.get_vector_length(shape)
if ndim is None:
raise ValueError('Cannot infer the number of dimensions from the shape argument.')
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论