提交 374d9b28 authored 作者: amrithasuresh's avatar amrithasuresh

Updated numpy as np

上级 2730024a
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
import numpy as np import numpy as np
import numpy
from six.moves import xrange from six.moves import xrange
import theano import theano
...@@ -778,7 +777,7 @@ def repeat(x, repeats, axis=None): ...@@ -778,7 +777,7 @@ def repeat(x, repeats, axis=None):
shape[axis] = shape[axis] * repeats shape[axis] = shape[axis] * repeats
# dims_ is the dimension of that intermediate tensor. # dims_ is the dimension of that intermediate tensor.
dims_ = list(numpy.arange(x.ndim)) dims_ = list(np.arange(x.ndim))
dims_.insert(axis + 1, 'x') dims_.insert(axis + 1, 'x')
# After the original tensor is duplicated along the additional # After the original tensor is duplicated along the additional
...@@ -806,7 +805,7 @@ class Bartlett(gof.Op): ...@@ -806,7 +805,7 @@ class Bartlett(gof.Op):
def perform(self, node, inputs, out_): def perform(self, node, inputs, out_):
M = inputs[0] M = inputs[0]
out, = out_ out, = out_
out[0] = numpy.bartlett(M) out[0] = np.bartlett(M)
def infer_shape(self, node, in_shapes): def infer_shape(self, node, in_shapes):
temp = node.inputs[0] temp = node.inputs[0]
...@@ -882,7 +881,7 @@ class FillDiagonal(gof.Op): ...@@ -882,7 +881,7 @@ class FillDiagonal(gof.Op):
# Write the value out into the diagonal. # Write the value out into the diagonal.
a.flat[:end:step] = val a.flat[:end:step] = val
else: else:
numpy.fill_diagonal(a, val) np.fill_diagonal(a, val)
output_storage[0][0] = a output_storage[0][0] = a
...@@ -1132,7 +1131,7 @@ class Unique(theano.Op): ...@@ -1132,7 +1131,7 @@ class Unique(theano.Op):
self.return_index = return_index self.return_index = return_index
self.return_inverse = return_inverse self.return_inverse = return_inverse
self.return_counts = return_counts self.return_counts = return_counts
numpy_ver = [int(n) for n in numpy.__version__.split('.')[:2]] numpy_ver = [int(n) for n in np.__version__.split('.')[:2]]
if self.return_counts and bool(numpy_ver < [1, 9]): if self.return_counts and bool(numpy_ver < [1, 9]):
raise RuntimeError( raise RuntimeError(
"Numpy version = " + np.__version__ + "Numpy version = " + np.__version__ +
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论