提交 843b6249 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Flake8 for compile/builders.py

上级 5d91204d
...@@ -15,7 +15,9 @@ class OpFromGraph(gof.Op): ...@@ -15,7 +15,9 @@ class OpFromGraph(gof.Op):
TODO: TODO:
- examples for a multi-layer mlp. where? - examples for a multi-layer mlp. where?
- __hash__, __eq__ otherwise won't merge, try gof.opt.is_same_graph_with_merge(op1.new_outputs, op2, new_outputs) - __hash__, __eq__ otherwise won't merge, try
gof.opt.is_same_graph_with_merge(op1.new_outputs, op2,
new_outputs)
- c_code() to remove the double overhead? - c_code() to remove the double overhead?
- opt to unfold it, work inplace on inputs - opt to unfold it, work inplace on inputs
- grad() make it support DisconnectedType and the new interface - grad() make it support DisconnectedType and the new interface
...@@ -76,8 +78,6 @@ class OpFromGraph(gof.Op): ...@@ -76,8 +78,6 @@ class OpFromGraph(gof.Op):
# not see them. Otherwise their is problem with the gradient. # not see them. Otherwise their is problem with the gradient.
self.shared_inputs = [var for var in gof.graph.inputs(outputs) self.shared_inputs = [var for var in gof.graph.inputs(outputs)
if isinstance(var, SharedVariable)] if isinstance(var, SharedVariable)]
used_inputs = [var for var in gof.graph.inputs(outputs)
if not isinstance(var, gof.Constant)]
shared_vars = [var.type() for var in self.shared_inputs] shared_vars = [var.type() for var in self.shared_inputs]
new = rebuild_collect_shared(outputs, inputs=inputs + shared_vars, new = rebuild_collect_shared(outputs, inputs=inputs + shared_vars,
replace=dict(zip(self.shared_inputs, replace=dict(zip(self.shared_inputs,
...@@ -110,8 +110,8 @@ class OpFromGraph(gof.Op): ...@@ -110,8 +110,8 @@ class OpFromGraph(gof.Op):
def make_node(self, *inputs): def make_node(self, *inputs):
for input, type in zip(inputs, self.input_types): for input, type in zip(inputs, self.input_types):
if not type == input.type: if not type == input.type:
raise TypeError("Wrong type, expected %s but got %s" raise TypeError("Wrong type, expected %s but got %s" %
% (type, input.type)) (type, input.type))
return gof.Apply(self, return gof.Apply(self,
list(inputs) + self.shared_inputs, list(inputs) + self.shared_inputs,
[type() for type in self.output_types]) [type() for type in self.output_types])
...@@ -143,7 +143,8 @@ class OpFromGraph(gof.Op): ...@@ -143,7 +143,8 @@ class OpFromGraph(gof.Op):
grad_ops = self.grad_ops grad_ops = self.grad_ops
else: else:
gs = theano.gradient.grad(cost=None, gs = theano.gradient.grad(cost=None,
known_grads=dict(zip(self.new_outputs, output_grads)), known_grads=dict(zip(self.new_outputs,
output_grads)),
wrt=self.new_inputs, wrt=self.new_inputs,
disconnected_inputs='ignore') disconnected_inputs='ignore')
......
...@@ -38,7 +38,6 @@ whitelist_flake8 = [ ...@@ -38,7 +38,6 @@ whitelist_flake8 = [
"tests/test_tutorial.py", "tests/test_tutorial.py",
"tests/disturb_mem.py", "tests/disturb_mem.py",
"tests/unittest_tools.py", "tests/unittest_tools.py",
"compile/builders.py",
"compile/__init__.py", "compile/__init__.py",
"compile/profiling.py", "compile/profiling.py",
"compile/function_module.py", "compile/function_module.py",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论