提交 64952355 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Use direct imports from theano.gof.graph in theano.link.c.basic

上级 ebd59050
...@@ -12,9 +12,10 @@ from io import StringIO ...@@ -12,9 +12,10 @@ from io import StringIO
import numpy as np import numpy as np
from theano.configdefaults import config from theano.configdefaults import config
from theano.gof import graph
from theano.gof.callcache import CallCache from theano.gof.callcache import CallCache
from theano.gof.compilelock import get_lock, release_lock from theano.gof.compilelock import get_lock, release_lock
from theano.gof.graph import Constant, NoParams, io_toposort
from theano.gof.graph import variables as get_variables
from theano.gof.utils import MethodNotDefined, difference, uniq from theano.gof.utils import MethodNotDefined, difference, uniq
from theano.link.basic import Container, Linker, LocalLinker, PerformLinker from theano.link.basic import Container, Linker, LocalLinker, PerformLinker
from theano.link.c.cmodule import ( from theano.link.c.cmodule import (
...@@ -635,14 +636,14 @@ class CLinker(Linker): ...@@ -635,14 +636,14 @@ class CLinker(Linker):
# We need to include the unused inputs in our variables, # We need to include the unused inputs in our variables,
# otherwise we can't pass them to the module. # otherwise we can't pass them to the module.
self.variables = [var for var in self.inputs if not len(fgraph.clients[var])] self.variables = [var for var in self.inputs if not len(fgraph.clients[var])]
self.variables += graph.variables(self.inputs, self.outputs) self.variables += get_variables(self.inputs, self.outputs)
# This adds a hidden input which is the params for each node # This adds a hidden input which is the params for each node
# that needs it # that needs it
self.node_params = dict() self.node_params = dict()
for node in self.node_order: for node in self.node_order:
params = node.run_params() params = node.run_params()
if params is not graph.NoParams: if params is not NoParams:
# try to avoid creating more than one variable for the # try to avoid creating more than one variable for the
# same params. # same params.
if params in self.node_params: if params in self.node_params:
...@@ -650,7 +651,7 @@ class CLinker(Linker): ...@@ -650,7 +651,7 @@ class CLinker(Linker):
assert var.type == node.params_type assert var.type == node.params_type
fgraph.clients[var].append((node, "params")) fgraph.clients[var].append((node, "params"))
else: else:
var = graph.Constant(node.params_type, params) var = Constant(node.params_type, params)
fgraph.clients[var] = [(node, "params")] fgraph.clients[var] = [(node, "params")]
self.node_params[params] = var self.node_params[params] = var
self.variables.append(var) self.variables.append(var)
...@@ -660,13 +661,13 @@ class CLinker(Linker): ...@@ -660,13 +661,13 @@ class CLinker(Linker):
self.orphans = list( self.orphans = list(
r r
for r in self.variables for r in self.variables
if isinstance(r, graph.Constant) and r not in self.inputs if isinstance(r, Constant) and r not in self.inputs
) )
# C type constants (theano.scalar.Scalar). They don't request an object # C type constants (theano.scalar.Scalar). They don't request an object
self.consts = [] self.consts = []
# Move c type from orphans (theano.scalar.Scalar) to self.consts # Move c type from orphans (theano.scalar.Scalar) to self.consts
for variable in self.orphans: for variable in self.orphans:
if isinstance(variable, graph.Constant): if isinstance(variable, Constant):
try: try:
variable.type.c_literal(variable.data) variable.type.c_literal(variable.data)
self.consts.append(variable) self.consts.append(variable)
...@@ -742,7 +743,7 @@ class CLinker(Linker): ...@@ -742,7 +743,7 @@ class CLinker(Linker):
[get_c_declare, get_c_extract, get_c_cleanup], [get_c_declare, get_c_extract, get_c_cleanup],
] ]
elif variable in self.orphans: elif variable in self.orphans:
if not isinstance(variable, graph.Constant): if not isinstance(variable, Constant):
raise TypeError( raise TypeError(
"All orphans to CLinker must be Constant" " instances.", "All orphans to CLinker must be Constant" " instances.",
variable, variable,
...@@ -818,7 +819,7 @@ class CLinker(Linker): ...@@ -818,7 +819,7 @@ class CLinker(Linker):
sub = dict(failure_var=failure_var) sub = dict(failure_var=failure_var)
params = node.run_params() params = node.run_params()
if params is not graph.NoParams: if params is not NoParams:
params_var = symbol[self.node_params[params]] params_var = symbol[self.node_params[params]]
# The placeholder will be replaced by a hash of the entire # The placeholder will be replaced by a hash of the entire
...@@ -833,13 +834,13 @@ class CLinker(Linker): ...@@ -833,13 +834,13 @@ class CLinker(Linker):
# Make the CodeBlock for c_code # Make the CodeBlock for c_code
sub["id"] = id sub["id"] = id
sub["fail"] = failure_code(sub) sub["fail"] = failure_code(sub)
if params is not graph.NoParams: if params is not NoParams:
sub["params"] = params_var sub["params"] = params_var
sub_struct = dict() sub_struct = dict()
sub_struct["id"] = id + 1 sub_struct["id"] = id + 1
sub_struct["fail"] = failure_code_init(sub) sub_struct["fail"] = failure_code_init(sub)
if params is not graph.NoParams: if params is not NoParams:
# Since params inputs are always constants they are # Since params inputs are always constants they are
# guaranteed to be available in the struct init code. # guaranteed to be available in the struct init code.
sub_struct["params"] = params_var sub_struct["params"] = params_var
...@@ -1399,7 +1400,7 @@ class CLinker(Linker): ...@@ -1399,7 +1400,7 @@ class CLinker(Linker):
# doesn't need to include information about inplace operations # doesn't need to include information about inplace operations
# because that information will be included explicitly in # because that information will be included explicitly in
# cmodule_key_(). # cmodule_key_().
return graph.io_toposort(self.inputs, self.outputs) return io_toposort(self.inputs, self.outputs)
fgraph = FakeFunctionGraph(inputs, outputs) fgraph = FakeFunctionGraph(inputs, outputs)
return self.cmodule_key_( return self.cmodule_key_(
...@@ -1495,7 +1496,7 @@ class CLinker(Linker): ...@@ -1495,7 +1496,7 @@ class CLinker(Linker):
# It is important that a variable (i) # It is important that a variable (i)
# yield a 'position' that reflects its role in code_gen() # yield a 'position' that reflects its role in code_gen()
if isinstance(i, graph.Constant): # orphans if isinstance(i, Constant): # orphans
if id(i) not in constant_ids: if id(i) not in constant_ids:
isig = (i.signature(), topological_pos, i_idx) isig = (i.signature(), topological_pos, i_idx)
# If the Theano constant provides a strong hash # If the Theano constant provides a strong hash
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论