提交 37764e73 authored 作者: kvmanohar22's avatar kvmanohar22

changed numpy imports to one common form

上级 6a888162
#!/usr/bin/env python
from __future__ import absolute_import, print_function, division
import numpy as N
import numpy as np
import sys
import time
from six.moves import xrange
......@@ -10,7 +10,7 @@ from six.moves import xrange
neg, nout, nhid, niter = [int(a) for a in sys.argv[1:]]
lr = 0.01
rng = N.random.RandomState(342)
rng = np.random.RandomState(342)
w = rng.rand(nout, nhid)
a = rng.randn(nhid) * 0.0
......@@ -22,38 +22,38 @@ dot_time = 0.0
t = time.time()
for i in xrange(niter):
tt = time.time()
d = N.dot(x, w)
d = np.dot(x, w)
dot_time += time.time() - tt
hid = N.tanh(d + a)
hid = np.tanh(d + a)
tt = time.time()
d = N.dot(hid, w.T)
d = np.dot(hid, w.T)
dot_time += time.time() - tt
out = N.tanh(d + b)
out = np.tanh(d + b)
g_out = out - x
err = 0.5 * N.sum(g_out**2)
err = 0.5 * np.sum(g_out**2)
g_hidwt = g_out * (1.0 - out**2)
b -= lr * N.sum(g_hidwt, axis=0)
b -= lr * np.sum(g_hidwt, axis=0)
tt = time.time()
g_hid = N.dot(g_hidwt, w)
g_hid = np.dot(g_hidwt, w)
dot_time += time.time() - tt
g_hidin = g_hid * (1.0 - hid**2)
tt = time.time()
d = N.dot(g_hidwt.T, hid)
dd = N.dot(x.T, g_hidin)
d = np.dot(g_hidwt.T, hid)
dd = np.dot(x.T, g_hidin)
dot_time += time.time() - tt
gw = (d + dd)
w -= lr * gw
a -= lr * N.sum(g_hidin, axis=0)
a -= lr * np.sum(g_hidin, axis=0)
total_time = time.time() - t
print('time: ',total_time, 'err: ', err)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论