提交 6dcf210b authored 作者: Olivier Breuleux's avatar Olivier Breuleux

merged

......@@ -155,11 +155,11 @@ class omega_op(gof.PythonOp):
def scalar_switch(x, y, normal_f, scalar_f):
x, y = wrap(x), wrap(y)
if x.constant and not x.data.shape:
return scalar_f(y, x)
if y.constant and not y.data.shape:
return scalar_f(x, y)
# x, y = wrap(x), wrap(y)
# if x.constant and not x.data.shape:
# return scalar_f(y, x)
# if y.constant and not y.data.shape:
# return scalar_f(x, y)
return normal_f(x, y)
......@@ -231,13 +231,13 @@ class iadd(proto_add, inplace):
class add_scalar(omega_op):
impl = numpy.ndarray.__add__
class iadd_scalar(omega_op):
class iadd_scalar(omega_op, inplace):
impl = numpy.ndarray.__iadd__
class proto_twice(omega_op):
def grad(x, gz):
return scal(gz, 2.0)
return scale(gz, 2.0)
class twice(proto_twice):
def impl(x):
......
......@@ -133,12 +133,10 @@ import grad
def dataset_1hot(x, targ, n):
"""Return an looping iterator over 1-hot vectors
This function is a generator for the integers range(n) that works by
side-effect on the numpy ndarray mat.
On each iteration, mat is set (in-place) to the next element of an infinite
sequence of 1-hot vectors.
"""
assert targ.size == 1
......@@ -197,15 +195,6 @@ print w.data
# # 1 = mul(mul(neg(scal(mul(sub(0.736213102665, sigmoid(*3)), 1.0), 2.0)), sigmoid(*3)), sub(1, sigmoid(*3)))
# # 2 = transpose(0.11474051836)
# # 3 = dot(*2, *5)
# # 4 = dot(0.11474051836, 0.736213102665)
# # 5 = sigmoid(*4)
# # add(transpose(dot(*1, transpose(*5))), dot(mul(mul(dot(transpose(*2), *1), sigmoid(*4)), sub(1, sigmoid(*4))), transpose(0.736213102665)))
############################
......@@ -225,8 +214,31 @@ print w.data
############################
print core.ones((2, 2)) + 1
# print core.ones((2, 2)) + 1
# print numpy.ones((2, 2)) ** numpy.ones((2, 2))
############################
x = core.ones((2, 2))
y = core.zeros((1, 1))
#print "?", gof.graph.ops([], [x + y])
print x
x + x
print "1", gof.eval_env#.ops()
y + y
print "2", gof.eval_env#.ops()
x + x
print "3", gof.eval_env#.ops()
print numpy.ones((2, 2)) ** numpy.ones((2, 2))
x += (x + x)
print x
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论