Unverified 提交 e7e92dce authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: GitHub

Merge pull request #58 from brandonwillard/jax-import-quick-fix

Quick fix for import issues introduced by #21
......@@ -596,16 +596,15 @@ AddConfigVar(
# Also, please be careful not to modify the first item in the enum when adding
# new modes, since it is the default mode.
def filter_mode(val):
if (
val
in [
"Mode",
"DebugMode",
"NanGuardMode",
"DEBUG_MODE",
]
or val in theano.compile.mode.predefined_modes
):
if val in [
"Mode",
"DebugMode",
"FAST_RUN",
"NanGuardMode",
"FAST_COMPILE",
"DEBUG_MODE",
"JAX",
]:
return val
# This can be executed before Theano is completly imported, so
# theano.Mode is not always available.
......
......@@ -35,13 +35,7 @@ from theano.tensor.basic import (
Reshape,
Join,
)
from theano.scalar.basic import (
ScalarOp,
Composite,
Cast,
Clip,
Identity
)
from theano.scalar.basic import ScalarOp, Composite, Cast, Clip, Identity
from theano.tensor.elemwise import Elemwise, CAReduce, DimShuffle
from theano.compile.ops import (
DeepCopyOp,
......
......@@ -6032,7 +6032,7 @@ ALL_REDUCE = [
T.elemwise.Sum,
T.elemwise.Prod,
T.elemwise.ProdWithoutZeros,
]
] + T.elemwise.CAReduce.__subclasses__()
@register_canonicalize
......
......@@ -33,12 +33,12 @@ supposed to be canonical.
import logging
import theano.tensor.basic as tt
import theano.scalar.basic as scal
from theano.tensor.elemwise import CAReduce
from theano.tensor import basic as tt
from theano import scalar as scal
from theano.tensor import DimShuffle, Subtensor
from theano.gof.opt import copy_stack_trace, local_optimizer
from theano.tensor.subtensor import Subtensor
from theano.tensor.elemwise import CAReduce, DimShuffle
from theano.tensor.opt import register_uncanonicalize
_logger = logging.getLogger("theano.tensor.opt")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论