提交 e832f70e authored 作者: James Bergstra's avatar James Bergstra

Added stricter checks to OpFromGraph constructor

上级 41678c83
...@@ -35,11 +35,21 @@ class OpFromGraph(gof.Op): ...@@ -35,11 +35,21 @@ class OpFromGraph(gof.Op):
""" """
def __init__(self, inputs, outputs, grad_depth = 1, **kwargs): def __init__(self, inputs, outputs, grad_depth = 1, **kwargs):
if not isinstance(outputs, list):
raise TypeError('outputs must be list', outputs)
for i in inputs + outputs:
if not isinstance(i, gof.Variable):
raise TypeError('inputs and outputs must be Variable instances', i)
if 'updates' in kwargs:
raise TypeError('updates are not allowed in kwargs')
# TODO: the graph may have implicit inputs like Value and SharedVariable instances.
# what impact to they have on the validity of this Op?
self.fn = function(inputs, outputs, **kwargs) self.fn = function(inputs, outputs, **kwargs)
self.inputs = inputs self.inputs = inputs
self.outputs = outputs self.outputs = outputs
self.input_types = [input.type for input in inputs] self.input_types = [input.type for input in inputs]
self.output_types = [output.type for output in outputs] self.output_types = [output.type for output in outputs]
if grad_depth > 0: if grad_depth > 0:
output_grads = [t() for t in self.output_types] output_grads = [t() for t in self.output_types]
gd = G.grad_sources_inputs(zip(self.outputs, output_grads), self.inputs) gd = G.grad_sources_inputs(zip(self.outputs, output_grads), self.inputs)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论