提交 d5f65500 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Use direct theano.gof imports in theano.scan.op

上级 9c445e0a
...@@ -53,23 +53,27 @@ from collections import OrderedDict ...@@ -53,23 +53,27 @@ from collections import OrderedDict
import numpy as np import numpy as np
import theano import theano
from theano import compile, gof, gradient, tensor from theano import tensor
from theano.compile.builders import infer_shape from theano.compile.builders import infer_shape
from theano.compile.function import function from theano.compile.function import function
from theano.compile.io import In, Out from theano.compile.io import In, Out
from theano.compile.mode import AddFeatureOptimizer from theano.compile.mode import AddFeatureOptimizer, get_mode
from theano.compile.profiling import ScanProfileStats from theano.compile.profiling import ScanProfileStats, register_profiler_printer
from theano.configdefaults import config from theano.configdefaults import config
from theano.gof import Apply, Op from theano.gof.fg import MissingInputError
from theano.gof.graph import equal_computations, io_connection_pattern from theano.gof.graph import Apply, Variable, equal_computations
from theano.gof.graph import inputs as graph_inputs
from theano.gof.graph import io_connection_pattern
from theano.gof.op import Op, ops_with_inner_function
from theano.gof.toolbox import NoOutputFromInplace from theano.gof.toolbox import NoOutputFromInplace
from theano.gradient import DisconnectedType, NullType, grad_undefined from theano.gradient import DisconnectedType, NullType, grad, grad_undefined
from theano.link.c.basic import CLinker from theano.link.c.basic import CLinker
from theano.link.c.exceptions import MissingGXX from theano.link.c.exceptions import MissingGXX
from theano.link.utils import raise_with_op from theano.link.utils import raise_with_op
from theano.scan.utils import Validator, forced_replace, hash_listsDictsTuples, safe_new from theano.scan.utils import Validator, forced_replace, hash_listsDictsTuples, safe_new
from theano.tensor import TensorType, as_tensor_variable from theano.tensor.basic import as_tensor_variable
from theano.tensor.opt import Shape_i from theano.tensor.opt import Shape_i
from theano.tensor.type import TensorType
__docformat__ = "restructedtext en" __docformat__ = "restructedtext en"
...@@ -169,7 +173,7 @@ class Scan(Op): ...@@ -169,7 +173,7 @@ class Scan(Op):
if self.as_while: if self.as_while:
self.output_types = self.output_types[:-1] self.output_types = self.output_types[:-1]
mode_instance = compile.mode.get_mode(self.mode) mode_instance = get_mode(self.mode)
# Clone mode_instance, altering "allow_gc" for the linker, # Clone mode_instance, altering "allow_gc" for the linker,
# and adding a message if we profile # and adding a message if we profile
if self.name: if self.name:
...@@ -202,11 +206,9 @@ class Scan(Op): ...@@ -202,11 +206,9 @@ class Scan(Op):
self._hash_inner_graph = self.info["gpu_hash"] self._hash_inner_graph = self.info["gpu_hash"]
else: else:
# Do the missing inputs check here to have the error early. # Do the missing inputs check here to have the error early.
for var in theano.gof.graph.inputs(self.outputs, self.inputs): for var in graph_inputs(self.outputs, self.inputs):
if var not in self.inputs and not isinstance(var, theano.Constant): if var not in self.inputs and not isinstance(var, theano.Constant):
raise theano.gof.MissingInputError( raise MissingInputError(f"ScanOp is missing an input: {repr(var)}")
f"ScanOp is missing an input: {repr(var)}"
)
self._cmodule_key = CLinker().cmodule_key_variables( self._cmodule_key = CLinker().cmodule_key_variables(
self.inputs, self.outputs, [] self.inputs, self.outputs, []
) )
...@@ -317,7 +319,7 @@ class Scan(Op): ...@@ -317,7 +319,7 @@ class Scan(Op):
the inner function) the inner function)
""" """
assert np.all(isinstance(i, gof.Variable) for i in inputs) assert np.all(isinstance(i, Variable) for i in inputs)
# Check that the number of inputs to the Scan node corresponds to # Check that the number of inputs to the Scan node corresponds to
# the number of inputs of the inner function of scan # the number of inputs of the inner function of scan
n_outer_ins = len(inputs) - len(self.outer_nitsot(inputs)) - 1 n_outer_ins = len(inputs) - len(self.outer_nitsot(inputs)) - 1
...@@ -2173,7 +2175,7 @@ class Scan(Op): ...@@ -2173,7 +2175,7 @@ class Scan(Op):
wrt = [ wrt = [
x x
for x in theano.gof.graph.inputs(y_s) for x in graph_inputs(y_s)
if (x in diff_inputs) if (x in diff_inputs)
and get_inp_idx(self_inputs.index(x)) in connected_inputs and get_inp_idx(self_inputs.index(x)) in connected_inputs
] ]
...@@ -2188,7 +2190,7 @@ class Scan(Op): ...@@ -2188,7 +2190,7 @@ class Scan(Op):
# to X. # to X.
known_grads = OrderedDict([(k.copy(), v) for (k, v) in known_grads.items()]) known_grads = OrderedDict([(k.copy(), v) for (k, v) in known_grads.items()])
grads = gradient.grad( grads = grad(
cost=None, cost=None,
known_grads=known_grads, known_grads=known_grads,
wrt=wrt, wrt=wrt,
...@@ -2238,7 +2240,7 @@ class Scan(Op): ...@@ -2238,7 +2240,7 @@ class Scan(Op):
) )
for pos, inp in enumerate(states): for pos, inp in enumerate(states):
if inp in theano.gof.graph.inputs([Xt]): if inp in graph_inputs([Xt]):
# Get the index of the outer output that to which # Get the index of the outer output that to which
# the state variable 'inp' corresponds. # the state variable 'inp' corresponds.
outer_oidx = self.var_mappings["outer_out_from_inner_inp"][ outer_oidx = self.var_mappings["outer_out_from_inner_inp"][
...@@ -2456,7 +2458,7 @@ class Scan(Op): ...@@ -2456,7 +2458,7 @@ class Scan(Op):
disconnected = False disconnected = False
for _sh in self.inner_shared(self_inputs): for _sh in self.inner_shared(self_inputs):
if _sh in gof.graph.inputs([dC_dinps_t[ins_pos]]): if _sh in graph_inputs([dC_dinps_t[ins_pos]]):
through_shared = True through_shared = True
ins_pos += 1 ins_pos += 1
...@@ -2511,7 +2513,7 @@ class Scan(Op): ...@@ -2511,7 +2513,7 @@ class Scan(Op):
if not disconnected_dC_dinps_t[ins_pos]: if not disconnected_dC_dinps_t[ins_pos]:
disconnected = False disconnected = False
for _sh in self.inner_shared(self_inputs): for _sh in self.inner_shared(self_inputs):
if _sh in gof.graph.inputs([dC_dinps_t[ins_pos]]): if _sh in graph_inputs([dC_dinps_t[ins_pos]]):
through_shared = True through_shared = True
n_mitmot_inps += 1 n_mitmot_inps += 1
...@@ -2559,7 +2561,7 @@ class Scan(Op): ...@@ -2559,7 +2561,7 @@ class Scan(Op):
inner_out_mitmot.append(dC_dinps_t[ins_pos]) inner_out_mitmot.append(dC_dinps_t[ins_pos])
for _sh in self.inner_shared(self_inputs): for _sh in self.inner_shared(self_inputs):
if _sh in gof.graph.inputs([dC_dinps_t[ins_pos]]): if _sh in graph_inputs([dC_dinps_t[ins_pos]]):
through_shared = True through_shared = True
if isinstance(dC_dinps_t[ins_pos].type, NullType): if isinstance(dC_dinps_t[ins_pos].type, NullType):
...@@ -2583,7 +2585,7 @@ class Scan(Op): ...@@ -2583,7 +2585,7 @@ class Scan(Op):
for _p, vl in enumerate(inner_out_sitsot): for _p, vl in enumerate(inner_out_sitsot):
through_shared = False through_shared = False
for _sh in self.inner_shared(self_inputs): for _sh in self.inner_shared(self_inputs):
if _sh in gof.graph.inputs([vl]): if _sh in graph_inputs([vl]):
through_shared = True through_shared = True
if isinstance(vl.type, NullType): if isinstance(vl.type, NullType):
type_outs.append(vl.type.why_null) type_outs.append(vl.type.why_null)
...@@ -2602,7 +2604,7 @@ class Scan(Op): ...@@ -2602,7 +2604,7 @@ class Scan(Op):
for _p, vl in enumerate(inner_out_nitsot): for _p, vl in enumerate(inner_out_nitsot):
through_shared = False through_shared = False
for _sh in self.inner_shared(self_inputs): for _sh in self.inner_shared(self_inputs):
if _sh in gof.graph.inputs([vl]): if _sh in graph_inputs([vl]):
through_shared = True through_shared = True
if isinstance(vl.type, NullType): if isinstance(vl.type, NullType):
type_outs.append(vl.type.why_null) type_outs.append(vl.type.why_null)
...@@ -3043,10 +3045,10 @@ class Scan(Op): ...@@ -3043,10 +3045,10 @@ class Scan(Op):
# Since Scan is an op that contains a Theano compiled function, it is # Since Scan is an op that contains a Theano compiled function, it is
# useful to let DebugMode know about it. # useful to let DebugMode know about it.
gof.ops_with_inner_function[Scan] = "fn" ops_with_inner_function[Scan] = "fn"
@theano.compile.profiling.register_profiler_printer @register_profiler_printer
def profile_printer( def profile_printer(
message, compile_time, fct_call_time, apply_time, apply_cimpl, outputs_size, file message, compile_time, fct_call_time, apply_time, apply_cimpl, outputs_size, file
): ):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论