提交 19dea460 authored 作者: khaotik's avatar khaotik

better compat with GPU

上级 da7f2835
...@@ -340,11 +340,10 @@ class OpFromGraph(gof.Op): ...@@ -340,11 +340,10 @@ class OpFromGraph(gof.Op):
return self._grad_op(*(list(inputs) + list(output_grads)), return_list=True) return self._grad_op(*(list(inputs) + list(output_grads)), return_list=True)
def make_node(self, *inputs): def make_node(self, *inputs):
for input, type in zip(inputs, self.input_types): num_expected_inps = len(self.local_inputs) - len(self.shared_inputs)
if not type == input.type: if len(inputs) != num_expected_inps:
raise TypeError("Wrong type, expected %s but got %s" % raise ValueError("Expected %d inputs, got %d" % (num_expected_inps, len(inputs)))
(type, input.type)) inputs = [inp_t.filter_variable(inp) for inp, inp_t in izip(inputs, self.input_types)]
apply_node = gof.Apply( apply_node = gof.Apply(
self, list(inputs) + self.shared_inputs, self, list(inputs) + self.shared_inputs,
[type() for type in self.output_types]) [type() for type in self.output_types])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论