提交 f2eca59f authored 作者: amrithasuresh's avatar amrithasuresh

Updated numpy as np

上级 7afbad40
...@@ -20,7 +20,7 @@ import logging ...@@ -20,7 +20,7 @@ import logging
import warnings import warnings
from collections import OrderedDict from collections import OrderedDict
import numpy import numpy as np
import theano import theano
from theano.compat import izip from theano.compat import izip
...@@ -589,8 +589,8 @@ def get_updates_and_outputs(ls): ...@@ -589,8 +589,8 @@ def get_updates_and_outputs(ls):
def isNaN_or_Inf_or_None(x): def isNaN_or_Inf_or_None(x):
isNone = x is None isNone = x is None
try: try:
isNaN = numpy.isnan(x) isNaN = np.isnan(x)
isInf = numpy.isinf(x) isInf = np.isinf(x)
isStr = isinstance(x, string_types) isStr = isinstance(x, string_types)
except Exception: except Exception:
isNaN = False isNaN = False
...@@ -599,8 +599,8 @@ def isNaN_or_Inf_or_None(x): ...@@ -599,8 +599,8 @@ def isNaN_or_Inf_or_None(x):
if not isNaN and not isInf: if not isNaN and not isInf:
try: try:
val = get_scalar_constant_value(x) val = get_scalar_constant_value(x)
isInf = numpy.isinf(val) isInf = np.isinf(val)
isNaN = numpy.isnan(val) isNaN = np.isnan(val)
except Exception: except Exception:
isNaN = False isNaN = False
isInf = False isInf = False
...@@ -959,7 +959,7 @@ def scan_can_remove_outs(op, out_idxs): ...@@ -959,7 +959,7 @@ def scan_can_remove_outs(op, out_idxs):
added = False added = False
for pos, idx in enumerate(out_idxs): for pos, idx in enumerate(out_idxs):
if (out_idxs_mask[pos] and if (out_idxs_mask[pos] and
numpy.any([x in required_inputs for x in out_ins[idx]])): np.any([x in required_inputs for x in out_ins[idx]])):
# This output is required .. # This output is required ..
out_idxs_mask[pos] = 0 out_idxs_mask[pos] = 0
required_inputs += gof.graph.inputs([op.outputs[idx]]) required_inputs += gof.graph.inputs([op.outputs[idx]])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论