提交 843b6249 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Flake8 for compile/builders.py

上级 5d91204d
......@@ -15,7 +15,9 @@ class OpFromGraph(gof.Op):
TODO:
- examples for a multi-layer mlp. where?
- __hash__, __eq__ otherwise won't merge, try gof.opt.is_same_graph_with_merge(op1.new_outputs, op2, new_outputs)
- __hash__, __eq__ otherwise won't merge, try
gof.opt.is_same_graph_with_merge(op1.new_outputs, op2,
new_outputs)
- c_code() to remove the double overhead?
- opt to unfold it, work inplace on inputs
- grad() make it support DisconnectedType and the new interface
......@@ -68,7 +70,7 @@ class OpFromGraph(gof.Op):
for i in inputs + outputs:
if not isinstance(i, gof.Variable):
raise TypeError(
'inputs and outputs must be Variable instances', i)
'inputs and outputs must be Variable instances', i)
if 'updates' in kwargs:
raise TypeError('updates are not allowed in kwargs')
......@@ -76,8 +78,6 @@ class OpFromGraph(gof.Op):
# not see them. Otherwise their is problem with the gradient.
self.shared_inputs = [var for var in gof.graph.inputs(outputs)
if isinstance(var, SharedVariable)]
used_inputs = [var for var in gof.graph.inputs(outputs)
if not isinstance(var, gof.Constant)]
shared_vars = [var.type() for var in self.shared_inputs]
new = rebuild_collect_shared(outputs, inputs=inputs + shared_vars,
replace=dict(zip(self.shared_inputs,
......@@ -110,8 +110,8 @@ class OpFromGraph(gof.Op):
def make_node(self, *inputs):
for input, type in zip(inputs, self.input_types):
if not type == input.type:
raise TypeError("Wrong type, expected %s but got %s"
% (type, input.type))
raise TypeError("Wrong type, expected %s but got %s" %
(type, input.type))
return gof.Apply(self,
list(inputs) + self.shared_inputs,
[type() for type in self.output_types])
......@@ -143,9 +143,10 @@ class OpFromGraph(gof.Op):
grad_ops = self.grad_ops
else:
gs = theano.gradient.grad(cost=None,
known_grads=dict(zip(self.new_outputs, output_grads)),
wrt=self.new_inputs,
disconnected_inputs='ignore')
known_grads=dict(zip(self.new_outputs,
output_grads)),
wrt=self.new_inputs,
disconnected_inputs='ignore')
grad_ops = []
for g in gs:
......
......@@ -38,7 +38,6 @@ whitelist_flake8 = [
"tests/test_tutorial.py",
"tests/disturb_mem.py",
"tests/unittest_tools.py",
"compile/builders.py",
"compile/__init__.py",
"compile/profiling.py",
"compile/function_module.py",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论