Unverified 提交 e7e92dce authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: GitHub

Merge pull request #58 from brandonwillard/jax-import-quick-fix

Quick fix for import issues introduced by #21
...@@ -596,16 +596,15 @@ AddConfigVar( ...@@ -596,16 +596,15 @@ AddConfigVar(
# Also, please be careful not to modify the first item in the enum when adding # Also, please be careful not to modify the first item in the enum when adding
# new modes, since it is the default mode. # new modes, since it is the default mode.
def filter_mode(val): def filter_mode(val):
if ( if val in [
val "Mode",
in [ "DebugMode",
"Mode", "FAST_RUN",
"DebugMode", "NanGuardMode",
"NanGuardMode", "FAST_COMPILE",
"DEBUG_MODE", "DEBUG_MODE",
] "JAX",
or val in theano.compile.mode.predefined_modes ]:
):
return val return val
# This can be executed before Theano is completly imported, so # This can be executed before Theano is completly imported, so
# theano.Mode is not always available. # theano.Mode is not always available.
......
...@@ -35,13 +35,7 @@ from theano.tensor.basic import ( ...@@ -35,13 +35,7 @@ from theano.tensor.basic import (
Reshape, Reshape,
Join, Join,
) )
from theano.scalar.basic import ( from theano.scalar.basic import ScalarOp, Composite, Cast, Clip, Identity
ScalarOp,
Composite,
Cast,
Clip,
Identity
)
from theano.tensor.elemwise import Elemwise, CAReduce, DimShuffle from theano.tensor.elemwise import Elemwise, CAReduce, DimShuffle
from theano.compile.ops import ( from theano.compile.ops import (
DeepCopyOp, DeepCopyOp,
......
...@@ -6032,7 +6032,7 @@ ALL_REDUCE = [ ...@@ -6032,7 +6032,7 @@ ALL_REDUCE = [
T.elemwise.Sum, T.elemwise.Sum,
T.elemwise.Prod, T.elemwise.Prod,
T.elemwise.ProdWithoutZeros, T.elemwise.ProdWithoutZeros,
] ] + T.elemwise.CAReduce.__subclasses__()
@register_canonicalize @register_canonicalize
......
...@@ -33,12 +33,12 @@ supposed to be canonical. ...@@ -33,12 +33,12 @@ supposed to be canonical.
import logging import logging
import theano.tensor.basic as tt from theano.tensor.elemwise import CAReduce
import theano.scalar.basic as scal from theano.tensor import basic as tt
from theano import scalar as scal
from theano.tensor import DimShuffle, Subtensor
from theano.gof.opt import copy_stack_trace, local_optimizer from theano.gof.opt import copy_stack_trace, local_optimizer
from theano.tensor.subtensor import Subtensor
from theano.tensor.elemwise import CAReduce, DimShuffle
from theano.tensor.opt import register_uncanonicalize from theano.tensor.opt import register_uncanonicalize
_logger = logging.getLogger("theano.tensor.opt") _logger = logging.getLogger("theano.tensor.opt")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论