提交 233feaf5 authored 作者: Michael Osthege's avatar Michael Osthege 提交者: Brandon T. Willard

Refactor use of some package-level imports from theano.link

上级 27448332
...@@ -6,7 +6,7 @@ from theano.gof import fg ...@@ -6,7 +6,7 @@ from theano.gof import fg
from theano.gof.graph import Apply, Constant, Variable from theano.gof.graph import Apply, Constant, Variable
from theano.gof.op import Op from theano.gof.op import Op
from theano.gof.type import Type from theano.gof.type import Type
from theano.link import PerformLinker from theano.link.basic import PerformLinker
from theano.link.c.cc import CLinker, DualLinker, OpWiseCLinker from theano.link.c.cc import CLinker, DualLinker, OpWiseCLinker
......
...@@ -7,7 +7,7 @@ from theano.gof import fg, graph ...@@ -7,7 +7,7 @@ from theano.gof import fg, graph
from theano.gof.graph import Apply, Constant, Variable from theano.gof.graph import Apply, Constant, Variable
from theano.gof.op import Op from theano.gof.op import Op
from theano.gof.type import Type from theano.gof.type import Type
from theano.link import Container, PerformLinker, WrapLinker from theano.link.basic import Container, PerformLinker, WrapLinker
from theano.utils import cmp from theano.utils import cmp
......
import theano
from theano.compile import Mode from theano.compile import Mode
from theano.printing import hex_digest from theano.configdefaults import config
from theano.link.basic import WrapLinkerMany
from theano.link.c.vm import VM_Linker
from theano.printing import hex_digest, min_informative_str
__authors__ = ["PyMC Team", "Ian Goodfellow"] __authors__ = ["PyMC Team", "Ian Goodfellow"]
...@@ -199,7 +201,7 @@ class RecordMode(Mode): ...@@ -199,7 +201,7 @@ class RecordMode(Mode):
print(f"str(node):{node}") print(f"str(node):{node}")
print("Symbolic inputs: ") print("Symbolic inputs: ")
for elem in node.inputs: for elem in node.inputs:
print(theano.printing.min_informative_str(elem)) print(min_informative_str(elem))
print("str(output) of outputs: ") print("str(output) of outputs: ")
for elem in fn.outputs: for elem in fn.outputs:
assert isinstance(elem, list) assert isinstance(elem, list)
...@@ -248,7 +250,7 @@ class RecordMode(Mode): ...@@ -248,7 +250,7 @@ class RecordMode(Mode):
handle_line(fgraph, line, i, node, fn) handle_line(fgraph, line, i, node, fn)
# linker = theano.link.c.cc.OpWiseCLinker() # linker = theano.link.c.cc.OpWiseCLinker()
linker = theano.link.c.vm.VM_Linker(use_cloop=bool(theano.config.cxx)) linker = VM_Linker(use_cloop=bool(config.cxx))
wrap_linker = theano.link.WrapLinkerMany([linker], [callback]) wrap_linker = WrapLinkerMany([linker], [callback])
super().__init__(wrap_linker, optimizer="fast_run") super().__init__(wrap_linker, optimizer="fast_run")
...@@ -2,8 +2,10 @@ ...@@ -2,8 +2,10 @@
# original author, and re-licensed under Theano's license. # original author, and re-licensed under Theano's license.
import numpy as np import numpy as np
import theano
from theano.compile.mode import Mode from theano.compile.mode import Mode
from theano.configdefaults import config
from theano.link.basic import WrapLinkerMany
from theano.link.c.cc import OpWiseCLinker
class MonitorMode(Mode): class MonitorMode(Mode):
...@@ -40,11 +42,9 @@ class MonitorMode(Mode): ...@@ -40,11 +42,9 @@ class MonitorMode(Mode):
def __init__(self, pre_func=None, post_func=None, optimizer="default", linker=None): def __init__(self, pre_func=None, post_func=None, optimizer="default", linker=None):
self.pre_func = pre_func self.pre_func = pre_func
self.post_func = post_func self.post_func = post_func
wrap_linker = theano.link.WrapLinkerMany( wrap_linker = WrapLinkerMany([OpWiseCLinker()], [self.eval])
[theano.link.c.cc.OpWiseCLinker()], [self.eval]
)
if optimizer == "default": if optimizer == "default":
optimizer = theano.config.optimizer optimizer = config.optimizer
if linker is not None and not isinstance(linker.mode, MonitorMode): if linker is not None and not isinstance(linker.mode, MonitorMode):
raise Exception( raise Exception(
"MonitorMode can only use its own linker! You " "MonitorMode can only use its own linker! You "
...@@ -95,13 +95,15 @@ class MonitorMode(Mode): ...@@ -95,13 +95,15 @@ class MonitorMode(Mode):
def detect_nan(fgraph, i, node, fn): def detect_nan(fgraph, i, node, fn):
from theano.printing import debugprint
for output in fn.outputs: for output in fn.outputs:
if ( if (
not isinstance(output[0], np.random.RandomState) not isinstance(output[0], np.random.RandomState)
and np.isnan(output[0]).any() and np.isnan(output[0]).any()
): ):
print("*** NaN detected ***") print("*** NaN detected ***")
theano.printing.debugprint(node) debugprint(node)
print("Inputs : %s" % [input[0] for input in fn.inputs]) print("Inputs : %s" % [input[0] for input in fn.inputs])
print("Outputs: %s" % [output[0] for output in fn.outputs]) print("Outputs: %s" % [output[0] for output in fn.outputs])
break break
...@@ -61,7 +61,8 @@ cimport numpy ...@@ -61,7 +61,8 @@ cimport numpy
import copy import copy
import time import time
from theano import gof, link from theano import gof
from theano.link.utils import raise_with_op
def get_version(): def get_version():
...@@ -405,7 +406,7 @@ def perform( ...@@ -405,7 +406,7 @@ def perform(
# done by raise_with_op is not implemented in C. # done by raise_with_op is not implemented in C.
if hasattr(fn, 'thunks'): if hasattr(fn, 'thunks'):
# For the CVM # For the CVM
link.raise_with_op(fn.maker.fgraph, raise_with_op(fn.maker.fgraph,
fn.nodes[fn.position_of_error], fn.nodes[fn.position_of_error],
fn.thunks[fn.position_of_error]) fn.thunks[fn.position_of_error])
else: else:
...@@ -413,7 +414,7 @@ def perform( ...@@ -413,7 +414,7 @@ def perform(
# We don't have access from python to all the # We don't have access from python to all the
# temps values So for now, we just don't print # temps values So for now, we just don't print
# the extra shapes/strides info # the extra shapes/strides info
link.raise_with_op(fn.maker.fgraph, fn.nodes[fn.position_of_error]) raise_with_op(fn.maker.fgraph, fn.nodes[fn.position_of_error])
else: else:
# old-style linkers raise their own exceptions # old-style linkers raise their own exceptions
raise raise
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论