Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
6a168155
提交
6a168155
authored
7月 24, 2012
作者:
nouiz
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #781 from nouiz/pep8
Pep8
上级
96ea23b6
9092e850
隐藏空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
302 行增加
和
161 行删除
+302
-161
nnet.py
theano/tensor/nnet/nnet.py
+302
-161
没有找到文件。
theano/tensor/nnet/nnet.py
浏览文件 @
6a168155
...
@@ -26,7 +26,8 @@ class SoftmaxWithBias(gof.Op):
...
@@ -26,7 +26,8 @@ class SoftmaxWithBias(gof.Op):
An L{Op} for the output of neural-net multiclass classifiers.
An L{Op} for the output of neural-net multiclass classifiers.
@type x: is a matrix of floats (32 or 64)
@type x: is a matrix of floats (32 or 64)
@type b: is a [row] vector of floats (32 or 64), length is number of cols in x
@type b: is a [row] vector of floats (32 or 64),
length is number of cols in x
This L{Op}'s output is softmax(x+b).
This L{Op}'s output is softmax(x+b).
softmax(x[i]) is the i'th distribution over len(x[i]) options.
softmax(x[i]) is the i'th distribution over len(x[i]) options.
...
@@ -34,13 +35,16 @@ class SoftmaxWithBias(gof.Op):
...
@@ -34,13 +35,16 @@ class SoftmaxWithBias(gof.Op):
nin
=
2
nin
=
2
nout
=
1
nout
=
1
def
__init__
(
self
,
**
kwargs
):
def
__init__
(
self
,
**
kwargs
):
gof
.
Op
.
__init__
(
self
,
**
kwargs
)
gof
.
Op
.
__init__
(
self
,
**
kwargs
)
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
return
type
(
self
)
==
type
(
other
)
return
type
(
self
)
==
type
(
other
)
def
__hash__
(
self
):
def
__hash__
(
self
):
return
tensor
.
hashtype
(
self
)
return
tensor
.
hashtype
(
self
)
def
__str__
(
self
):
def
__str__
(
self
):
return
self
.
__class__
.
__name__
return
self
.
__class__
.
__name__
...
@@ -74,14 +78,14 @@ class SoftmaxWithBias(gof.Op):
...
@@ -74,14 +78,14 @@ class SoftmaxWithBias(gof.Op):
g_sm
,
=
grads
g_sm
,
=
grads
sm
=
softmax_with_bias
(
x
,
b
)
sm
=
softmax_with_bias
(
x
,
b
)
dx
=
softmax_grad
(
g_sm
,
sm
)
dx
=
softmax_grad
(
g_sm
,
sm
)
db
=
tensor
.
sum
(
dx
,
axis
=
0
)
db
=
tensor
.
sum
(
dx
,
axis
=
0
)
return
dx
,
db
return
dx
,
db
def
infer_shape
(
self
,
node
,
shape
):
def
infer_shape
(
self
,
node
,
shape
):
return
[
shape
[
0
]]
return
[
shape
[
0
]]
def
c_headers
(
self
):
def
c_headers
(
self
):
return
[
'<iostream>'
,
'<cmath>'
]
return
[
'<iostream>'
,
'<cmath>'
]
@staticmethod
@staticmethod
def
c_code_template
():
def
c_code_template
():
...
@@ -107,19 +111,22 @@ class SoftmaxWithBias(gof.Op):
...
@@ -107,19 +111,22 @@ class SoftmaxWithBias(gof.Op):
PyErr_SetString(PyExc_ValueError, "b not 1d tensor");
PyErr_SetString(PyExc_ValueError, "b not 1d tensor");
%(fail)
s;
%(fail)
s;
}
}
if ((
%(x)
s->descr->type_num != PyArray_DOUBLE)&&(
%(x)
s->descr->type_num != PyArray_FLOAT))
if ((
%(x)
s->descr->type_num != PyArray_DOUBLE) &&
(
%(x)
s->descr->type_num != PyArray_FLOAT))
{
{
PyErr_SetString(PyExc_TypeError, "a not float");
PyErr_SetString(PyExc_TypeError, "a not float");
%(fail)
s;
%(fail)
s;
}
}
if ((
%(b)
s->descr->type_num != PyArray_DOUBLE) && (
%(b)
s->descr->type_num != PyArray_FLOAT))
if ((
%(b)
s->descr->type_num != PyArray_DOUBLE) &&
(
%(b)
s->descr->type_num != PyArray_FLOAT))
{
{
PyErr_SetString(PyExc_TypeError, "b not float");
PyErr_SetString(PyExc_TypeError, "b not float");
%(fail)
s;
%(fail)
s;
}
}
if ((
%(x)
s->dimensions[1] !=
%(b)
s->dimensions[0]))
if ((
%(x)
s->dimensions[1] !=
%(b)
s->dimensions[0]))
{
{
PyErr_Format(PyExc_ValueError, "number of columns in x (
%%
ld) does not match length of b (
%%
ld)",
PyErr_Format(PyExc_ValueError,
"number of columns in x (
%%
ld) does not match length of b (
%%
ld)",
(long int)
%(x)
s->dimensions[1], (long int)
%(b)
s->dimensions[0]);
(long int)
%(x)
s->dimensions[1], (long int)
%(b)
s->dimensions[0]);
%(fail)
s;
%(fail)
s;
}
}
...
@@ -129,9 +136,11 @@ class SoftmaxWithBias(gof.Op):
...
@@ -129,9 +136,11 @@ class SoftmaxWithBias(gof.Op):
|| (
%(sm)
s->dimensions[1] !=
%(x)
s->dimensions[1]))
|| (
%(sm)
s->dimensions[1] !=
%(x)
s->dimensions[1]))
{
{
if (NULL !=
%(sm)
s) Py_XDECREF(
%(sm)
s);
if (NULL !=
%(sm)
s) Py_XDECREF(
%(sm)
s);
%(sm)
s = (PyArrayObject*)PyArray_SimpleNew(2, PyArray_DIMS(
%(x)
s), type_num_
%(x)
s);
%(sm)
s = (PyArrayObject*)PyArray_SimpleNew(2, PyArray_DIMS(
%(x)
s),
type_num_
%(x)
s);
if(!
%(sm)
s) {
if(!
%(sm)
s) {
PyErr_SetString(PyExc_MemoryError, "failed to alloc sm output");
PyErr_SetString(PyExc_MemoryError,
"failed to alloc sm output");
%(fail)
s
%(fail)
s
}
}
}
}
...
@@ -146,7 +155,7 @@ class SoftmaxWithBias(gof.Op):
...
@@ -146,7 +155,7 @@ class SoftmaxWithBias(gof.Op):
const dtype_
%(x)
s* __restrict__ x_i = (dtype_
%(x)
s*)(
%(x)
s->data +
%(x)
s->strides[0] * i);
const dtype_
%(x)
s* __restrict__ x_i = (dtype_
%(x)
s*)(
%(x)
s->data +
%(x)
s->strides[0] * i);
const dtype_
%(b)
s* __restrict__ b_i = (dtype_
%(b)
s*)(
%(b)
s->data);
const dtype_
%(b)
s* __restrict__ b_i = (dtype_
%(b)
s*)(
%(b)
s->data);
dtype_
%(sm)
s* __restrict__ sm_i = (dtype_
%(sm)
s*)(
%(sm)
s->data +
%(sm)
s->strides[0] * i);
dtype_
%(sm)
s* __restrict__ sm_i = (dtype_
%(sm)
s*)(
%(sm)
s->data +
%(sm)
s->strides[0] * i);
"""
"""
inside_row_loop
=
"""
inside_row_loop
=
"""
...
@@ -191,7 +200,6 @@ class SoftmaxWithBias(gof.Op):
...
@@ -191,7 +200,6 @@ class SoftmaxWithBias(gof.Op):
return
(
init_decl
,
begin_row_loop
,
inside_row_loop
,
end_row_loop
)
return
(
init_decl
,
begin_row_loop
,
inside_row_loop
,
end_row_loop
)
def
c_code
(
self
,
node
,
name
,
inp
,
out
,
sub
):
def
c_code
(
self
,
node
,
name
,
inp
,
out
,
sub
):
x
,
b
=
inp
x
,
b
=
inp
sm
,
=
out
sm
,
=
out
...
@@ -205,7 +213,6 @@ class SoftmaxWithBias(gof.Op):
...
@@ -205,7 +213,6 @@ class SoftmaxWithBias(gof.Op):
softmax_with_bias
=
SoftmaxWithBias
()
softmax_with_bias
=
SoftmaxWithBias
()
class
SoftmaxGrad
(
gof
.
Op
):
class
SoftmaxGrad
(
gof
.
Op
):
"""Gradient wrt x of the Softmax Op"""
"""Gradient wrt x of the Softmax Op"""
nin
=
2
nin
=
2
...
@@ -245,18 +252,23 @@ class SoftmaxGrad(gof.Op):
...
@@ -245,18 +252,23 @@ class SoftmaxGrad(gof.Op):
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
return
(
3
,)
return
(
3
,)
def
c_code
(
self
,
node
,
name
,
inp
,
out
,
sub
):
def
c_code
(
self
,
node
,
name
,
inp
,
out
,
sub
):
dy
,
sm
=
inp
dy
,
sm
=
inp
dx
,
=
out
dx
,
=
out
return
'''
return
'''
if ((
%(dy)
s->descr->type_num != PyArray_DOUBLE) && (
%(dy)
s->descr->type_num != PyArray_FLOAT))
if ((
%(dy)
s->descr->type_num != PyArray_DOUBLE) &&
(
%(dy)
s->descr->type_num != PyArray_FLOAT))
{
{
PyErr_SetString(PyExc_TypeError, "types should be float or float64");
PyErr_SetString(PyExc_TypeError,
"types should be float or float64");
%(fail)
s;
%(fail)
s;
}
}
if ((
%(sm)
s->descr->type_num != PyArray_DOUBLE) && (
%(sm)
s->descr->type_num != PyArray_FLOAT))
if ((
%(sm)
s->descr->type_num != PyArray_DOUBLE) &&
(
%(sm)
s->descr->type_num != PyArray_FLOAT))
{
{
PyErr_SetString(PyExc_TypeError, "types should be float or float64");
PyErr_SetString(PyExc_TypeError,
"types should be float or float64");
%(fail)
s;
%(fail)
s;
}
}
if ((
%(dy)
s->nd != 2)
if ((
%(dy)
s->nd != 2)
...
@@ -275,11 +287,13 @@ class SoftmaxGrad(gof.Op):
...
@@ -275,11 +287,13 @@ class SoftmaxGrad(gof.Op):
|| (
%(dx)
s->dimensions[1] !=
%(sm)
s->dimensions[1]))
|| (
%(dx)
s->dimensions[1] !=
%(sm)
s->dimensions[1]))
{
{
Py_XDECREF(
%(dx)
s);
Py_XDECREF(
%(dx)
s);
%(dx)
s = (PyArrayObject*) PyArray_SimpleNew(2, PyArray_DIMS(
%(sm)
s),
%(dx)
s = (PyArrayObject*) PyArray_SimpleNew(2,
PyArray_DIMS(
%(sm)
s),
type_num_
%(sm)
s);
type_num_
%(sm)
s);
if (!
%(dx)
s)
if (!
%(dx)
s)
{
{
PyErr_SetString(PyExc_MemoryError, "failed to alloc dx output");
PyErr_SetString(PyExc_MemoryError,
"failed to alloc dx output");
%(fail)
s;
%(fail)
s;
}
}
}
}
...
@@ -290,7 +304,7 @@ class SoftmaxGrad(gof.Op):
...
@@ -290,7 +304,7 @@ class SoftmaxGrad(gof.Op):
npy_intp Sdy =
%(dy)
s->strides[1]/sizeof(dtype_
%(dy)
s);
npy_intp Sdy =
%(dy)
s->strides[1]/sizeof(dtype_
%(dy)
s);
const dtype_
%(sm)
s* __restrict__ sm_i = (dtype_
%(sm)
s*) (
%(sm)
s->data +
%(sm)
s->strides[0] * i);
const dtype_
%(sm)
s* __restrict__ sm_i = (dtype_
%(sm)
s*) (
%(sm)
s->data +
%(sm)
s->strides[0] * i);
npy_intp Ssm =
%(sm)
s->strides[1]/sizeof(dtype_
%(sm)
s);
npy_intp Ssm =
%(sm)
s->strides[1]/sizeof(dtype_
%(sm)
s);
dtype_
%(dx)
s* __restrict__ dx_i = (dtype_
%(dx)
s*) (
%(dx)
s->data +
%(dx)
s->strides[0] * i);
dtype_
%(dx)
s* __restrict__ dx_i = (dtype_
%(dx)
s*) (
%(dx)
s->data +
%(dx)
s->strides[0] * i);
npy_intp Sdx =
%(dx)
s->strides[1]/sizeof(dtype_
%(dx)
s);
npy_intp Sdx =
%(dx)
s->strides[1]/sizeof(dtype_
%(dx)
s);
double sum_dy_times_sm = 0.;
double sum_dy_times_sm = 0.;
...
@@ -307,6 +321,7 @@ class SoftmaxGrad(gof.Op):
...
@@ -307,6 +321,7 @@ class SoftmaxGrad(gof.Op):
'''
%
dict
(
locals
(),
**
sub
)
'''
%
dict
(
locals
(),
**
sub
)
softmax_grad
=
SoftmaxGrad
()
softmax_grad
=
SoftmaxGrad
()
class
Softmax
(
gof
.
Op
):
class
Softmax
(
gof
.
Op
):
"""
"""
WRITEME
WRITEME
...
@@ -314,12 +329,16 @@ class Softmax(gof.Op):
...
@@ -314,12 +329,16 @@ class Softmax(gof.Op):
nin
=
1
nin
=
1
nout
=
1
nout
=
1
def
__init__
(
self
,
**
kwargs
):
def
__init__
(
self
,
**
kwargs
):
gof
.
Op
.
__init__
(
self
,
**
kwargs
)
gof
.
Op
.
__init__
(
self
,
**
kwargs
)
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
return
type
(
self
)
==
type
(
other
)
return
type
(
self
)
==
type
(
other
)
def
__hash__
(
self
):
def
__hash__
(
self
):
return
hash
(
type
(
self
))
return
hash
(
type
(
self
))
def
__str__
(
self
):
def
__str__
(
self
):
return
self
.
__class__
.
__name__
return
self
.
__class__
.
__name__
...
@@ -359,6 +378,7 @@ class Softmax(gof.Op):
...
@@ -359,6 +378,7 @@ class Softmax(gof.Op):
softmax
=
Softmax
()
softmax
=
Softmax
()
@opt.register_specialize
@opt.register_specialize
@gof.local_optimizer
([
softmax
])
@gof.local_optimizer
([
softmax
])
def
local_softmax_with_bias
(
node
):
def
local_softmax_with_bias
(
node
):
...
@@ -371,35 +391,38 @@ def local_softmax_with_bias(node):
...
@@ -371,35 +391,38 @@ def local_softmax_with_bias(node):
non_vectors
=
[]
non_vectors
=
[]
for
x_in
in
x
.
owner
.
inputs
:
for
x_in
in
x
.
owner
.
inputs
:
if
list
(
x_in
.
type
.
broadcastable
)
==
[
True
,
False
]:
if
list
(
x_in
.
type
.
broadcastable
)
==
[
True
,
False
]:
# print isinstance(x_in.owner.op, tensor.DimShuffle)
# print isinstance(x_in.owner.op,
#since specialization comes relatively late in optimization,
#tensor.DimShuffle) since specialization comes
# we don't want to put in extra DimShuffles un-necessarily.
#relatively late in optimization, we don't want to
if
x_in
.
owner
and
isinstance
(
x_in
.
owner
.
op
,
tensor
.
DimShuffle
)
\
#put in extra DimShuffles un-necessarily.
and
list
(
x_in
.
owner
.
inputs
[
0
]
.
type
.
broadcastable
)
==
[
False
]:
if
(
x_in
.
owner
and
isinstance
(
x_in
.
owner
.
op
,
tensor
.
DimShuffle
)
and
list
(
x_in
.
owner
.
inputs
[
0
]
.
type
.
broadcastable
)
==
[
False
]):
# cut out the DimShuffle that was broadcasting a vector
# cut out the DimShuffle that was broadcasting a vector
vectors
.
append
(
x_in
.
owner
.
inputs
[
0
])
vectors
.
append
(
x_in
.
owner
.
inputs
[
0
])
else
:
else
:
# insert an extra DimShuffle to correct the old one
# insert an extra DimShuffle to correct the old one
vectors
.
append
(
tensor
.
DimShuffle
((
True
,
False
),
(
1
,))(
x_in
))
vectors
.
append
(
tensor
.
DimShuffle
((
True
,
False
),
(
1
,))(
x_in
))
else
:
else
:
non_vectors
.
append
(
x_in
)
non_vectors
.
append
(
x_in
)
# If all the inputs were vectors or broadcasted vectors,
# If all the inputs were vectors or broadcasted vectors,
# we broadcast one of them to be used as a matrix
# we broadcast one of them to be used as a matrix
if
len
(
non_vectors
)
==
0
:
if
len
(
non_vectors
)
==
0
:
assert
len
(
vectors
)
>
0
# we should have at least 1 input...
assert
len
(
vectors
)
>
0
# we should have at least 1 input...
promoted_vector
=
vectors
.
pop
()
promoted_vector
=
vectors
.
pop
()
non_vectors
.
append
(
tensor
.
shape_padleft
(
promoted_vector
))
non_vectors
.
append
(
tensor
.
shape_padleft
(
promoted_vector
))
assert
non_vectors
#
not empty
assert
non_vectors
#
not empty
if
vectors
:
if
vectors
:
#we're in business...
#we're in business...
if
len
(
vectors
)
>
1
:
if
len
(
vectors
)
>
1
:
vector_sum
=
tensor
.
add
(
*
vectors
)
vector_sum
=
tensor
.
add
(
*
vectors
)
else
:
else
:
vector_sum
=
vectors
[
0
]
vector_sum
=
vectors
[
0
]
if
len
(
non_vectors
)
>
1
:
if
len
(
non_vectors
)
>
1
:
non_vector_sum
=
tensor
.
add
(
*
non_vectors
)
non_vector_sum
=
tensor
.
add
(
*
non_vectors
)
else
:
else
:
non_vector_sum
=
non_vectors
[
0
]
non_vector_sum
=
non_vectors
[
0
]
...
@@ -407,7 +430,8 @@ def local_softmax_with_bias(node):
...
@@ -407,7 +430,8 @@ def local_softmax_with_bias(node):
try
:
try
:
sm_bias
=
softmax_with_bias
(
non_vector_sum
,
vector_sum
)
sm_bias
=
softmax_with_bias
(
non_vector_sum
,
vector_sum
)
except
Exception
:
except
Exception
:
#if our arguments have the wrong types, then forget about it
#if our arguments have the wrong types, then
#forget about it
return
return
if
sm_bias
.
type
==
node
.
outputs
[
0
]
.
type
:
if
sm_bias
.
type
==
node
.
outputs
[
0
]
.
type
:
...
@@ -415,6 +439,7 @@ def local_softmax_with_bias(node):
...
@@ -415,6 +439,7 @@ def local_softmax_with_bias(node):
#nnet/tests/test_nnet.py:T_SoftmaxWithBias.test_broadcast
#nnet/tests/test_nnet.py:T_SoftmaxWithBias.test_broadcast
return
[
sm_bias
]
return
[
sm_bias
]
def
softmax_simplifier
(
numerators
,
denominators
):
def
softmax_simplifier
(
numerators
,
denominators
):
for
numerator
in
list
(
numerators
):
for
numerator
in
list
(
numerators
):
#TODO: a single softmax'd vector??
#TODO: a single softmax'd vector??
...
@@ -431,9 +456,11 @@ def softmax_simplifier(numerators, denominators):
...
@@ -431,9 +456,11 @@ def softmax_simplifier(numerators, denominators):
matching_denom
=
None
matching_denom
=
None
for
denominator
in
denominators
:
for
denominator
in
denominators
:
if
denominator
.
owner
and
isinstance
(
denominator
.
owner
.
op
,
tensor
.
DimShuffle
):
if
denominator
.
owner
and
isinstance
(
denominator
.
owner
.
op
,
if
denominator
.
owner
.
op
.
new_order
==
(
0
,
'x'
):
tensor
.
DimShuffle
):
z
=
denominator
.
owner
.
inputs
[
0
]
# thing getting dimshuffled
if
denominator
.
owner
.
op
.
new_order
==
(
0
,
'x'
):
z
=
denominator
.
owner
.
inputs
[
0
]
# thing getting dimshuffled
if
z
.
owner
and
isinstance
(
z
.
owner
.
op
,
tensor
.
Sum
):
if
z
.
owner
and
isinstance
(
z
.
owner
.
op
,
tensor
.
Sum
):
#print 'ASDF', denominator.owner.op.new_order
#print 'ASDF', denominator.owner.op.new_order
#print z.owner.op.axis
#print z.owner.op.axis
...
@@ -447,7 +474,8 @@ def softmax_simplifier(numerators, denominators):
...
@@ -447,7 +474,8 @@ def softmax_simplifier(numerators, denominators):
denominators
.
remove
(
matching_denom
)
denominators
.
remove
(
matching_denom
)
numerators
.
append
(
softmax
(
x
))
numerators
.
append
(
softmax
(
x
))
return
numerators
,
denominators
return
numerators
,
denominators
opt
.
local_mul_canonizer
.
add_simplifier
(
softmax_simplifier
,
'softmax_simplifier'
)
opt
.
local_mul_canonizer
.
add_simplifier
(
softmax_simplifier
,
'softmax_simplifier'
)
if
0
:
if
0
:
@opt.register_specialize
@opt.register_specialize
...
@@ -457,7 +485,7 @@ if 0:
...
@@ -457,7 +485,7 @@ if 0:
#TODO what if the signs are changed?
#TODO what if the signs are changed?
#TODO and if a scalar is distributed before each of the terms?
#TODO and if a scalar is distributed before each of the terms?
#TODO 'dy' could also be a product
#TODO 'dy' could also be a product
if
node
.
op
==
tensor
.
add
and
node
.
out
.
ndim
==
2
:
if
node
.
op
==
tensor
.
add
and
node
.
out
.
ndim
==
2
:
add_inputs
=
node
.
inputs
add_inputs
=
node
.
inputs
# Trying to locate two nodes in the sum:
# Trying to locate two nodes in the sum:
# dy * sm, prod_term
# dy * sm, prod_term
...
@@ -466,9 +494,12 @@ if 0:
...
@@ -466,9 +494,12 @@ if 0:
other_terms
=
[]
other_terms
=
[]
# First, prod_term
# First, prod_term
for
add_in
in
add_inputs
:
for
add_in
in
add_inputs
:
if
add_in
.
owner
and
add_in
.
owner
.
op
==
tensor
.
mul
and
prod_term
is
None
:
if
(
add_in
.
owner
and
add_in
.
owner
.
op
==
tensor
.
mul
and
prod_term
is
None
):
mul_inputs
=
add_in
.
owner
.
inputs
mul_inputs
=
add_in
.
owner
.
inputs
if
len
(
mul_inputs
)
==
2
and
all
([
mul_in
.
ndim
==
2
for
mul_in
in
mul_inputs
]):
if
(
len
(
mul_inputs
)
==
2
and
all
([
mul_in
.
ndim
==
2
for
mul_in
in
mul_inputs
])):
prod_term
=
add_in
prod_term
=
add_in
else
:
else
:
other_terms
.
append
(
add_in
)
other_terms
.
append
(
add_in
)
...
@@ -477,7 +508,7 @@ if 0:
...
@@ -477,7 +508,7 @@ if 0:
if
prod_term
is
None
:
if
prod_term
is
None
:
#print 'no prod_term'
#print 'no prod_term'
return
return
assert
len
(
other_terms
)
==
len
(
add_inputs
)
-
1
assert
len
(
other_terms
)
==
len
(
add_inputs
)
-
1
ds_term
=
None
ds_term
=
None
rest
=
[]
rest
=
[]
...
@@ -493,10 +524,13 @@ if 0:
...
@@ -493,10 +524,13 @@ if 0:
# Try and find DimShuffle(Sum)
# Try and find DimShuffle(Sum)
maybe_ds
=
None
maybe_ds
=
None
for
i
,
mul2_in
in
enumerate
(
mul2_inputs
):
for
i
,
mul2_in
in
enumerate
(
mul2_inputs
):
if
mul2_in
.
owner
and
isinstance
(
mul2_in
.
owner
.
op
,
elemwise
.
DimShuffle
):
if
mul2_in
.
owner
and
isinstance
(
mul2_in
.
owner
.
op
,
elemwise
.
DimShuffle
):
maybe_ds
=
mul2_in
maybe_ds
=
mul2_in
maybe_sm
=
mul2_inputs
[
1
-
i
]
# The other one
maybe_sm
=
mul2_inputs
[
1
-
i
]
# The other one
if
maybe_ds
is
None
or
maybe_ds
.
ndim
!=
2
or
maybe_sm
.
ndim
!=
2
:
if
(
maybe_ds
is
None
or
maybe_ds
.
ndim
!=
2
or
maybe_sm
.
ndim
!=
2
):
rest
.
append
(
add_in
)
rest
.
append
(
add_in
)
#print 'maybe_ds =', maybe_ds
#print 'maybe_ds =', maybe_ds
#if maybe_ds:
#if maybe_ds:
...
@@ -516,11 +550,14 @@ if 0:
...
@@ -516,11 +550,14 @@ if 0:
ds_order
=
maybe_ds
.
owner
.
op
.
new_order
ds_order
=
maybe_ds
.
owner
.
op
.
new_order
ds_input
=
maybe_ds
.
owner
.
inputs
[
0
]
ds_input
=
maybe_ds
.
owner
.
inputs
[
0
]
axis
=
None
axis
=
None
if
ds_input
.
owner
and
isinstance
(
ds_input
.
owner
.
op
,
elemwise
.
Sum
):
if
ds_input
.
owner
and
isinstance
(
ds_input
.
owner
.
op
,
elemwise
.
Sum
):
axis
=
ds_input
.
owner
.
op
.
axis
axis
=
ds_input
.
owner
.
op
.
axis
sum_input
=
ds_input
.
owner
.
inputs
[
0
]
sum_input
=
ds_input
.
owner
.
inputs
[
0
]
if
(
ds_order
!=
(
0
,
'x'
))
or
(
axis
!=
(
1
,))
or
(
sum_input
is
not
prod_term
):
if
((
ds_order
!=
(
0
,
'x'
))
or
(
axis
!=
(
1
,))
or
(
sum_input
is
not
prod_term
)):
rest
.
append
(
add_in
)
rest
.
append
(
add_in
)
#print 'ds_order =', ds_order
#print 'ds_order =', ds_order
#print 'axis =', axis
#print 'axis =', axis
...
@@ -553,12 +590,15 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op):
...
@@ -553,12 +590,15 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op):
"""A special compound L{Op} for the output of neural-net classifiers.
"""A special compound L{Op} for the output of neural-net classifiers.
:type x: is a matrix of floats (32 or 64)
:type x: is a matrix of floats (32 or 64)
:type b: is a [row] vector of floats (32 or 64), length is number of cols in x
:type b: is a [row] vector of floats (32 or 64),
:type y_idx: a [column] vector of int (32 or 64), length is number of rows in x
length is number of cols in x
:type y_idx: a [column] vector of int (32 or 64),
length is number of rows in x
:returns: row-wise NLL, softmax(x+b), row-wise argmax of (x+b)
:returns: row-wise NLL, softmax(x+b), row-wise argmax of (x+b)
@precondition: every entry in y_idx is a valid (non-negative) column index into x
@precondition: every entry in y_idx is a valid (non-negative)
column index into x
This L{Op} has three outputs:
This L{Op} has three outputs:
- KL(softmax(x+b), y)
- KL(softmax(x+b), y)
...
@@ -574,16 +614,21 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op):
...
@@ -574,16 +614,21 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op):
i'th example.
i'th example.
"""
"""
nin
=
3
nin
=
3
nout
=
3
nout
=
3
def
__init__
(
self
,
**
kwargs
):
def
__init__
(
self
,
**
kwargs
):
gof
.
Op
.
__init__
(
self
,
**
kwargs
)
gof
.
Op
.
__init__
(
self
,
**
kwargs
)
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
return
type
(
self
)
==
type
(
other
)
return
type
(
self
)
==
type
(
other
)
def
__hash__
(
self
):
def
__hash__
(
self
):
return
tensor
.
hashtype
(
self
)
return
tensor
.
hashtype
(
self
)
def
__str__
(
self
):
def
__str__
(
self
):
return
self
.
__class__
.
__name__
return
self
.
__class__
.
__name__
def
make_node
(
self
,
x
,
b
,
y_idx
):
def
make_node
(
self
,
x
,
b
,
y_idx
):
x
=
tensor
.
as_tensor_variable
(
x
)
x
=
tensor
.
as_tensor_variable
(
x
)
b
=
tensor
.
as_tensor_variable
(
b
)
b
=
tensor
.
as_tensor_variable
(
b
)
...
@@ -605,14 +650,15 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op):
...
@@ -605,14 +650,15 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op):
sm
=
x
.
type
.
make_variable
()
sm
=
x
.
type
.
make_variable
()
am
=
y_idx
.
type
.
make_variable
()
am
=
y_idx
.
type
.
make_variable
()
return
Apply
(
self
,
[
x
,
b
,
y_idx
],
[
nll
,
sm
,
am
])
return
Apply
(
self
,
[
x
,
b
,
y_idx
],
[
nll
,
sm
,
am
])
def
perform
(
self
,
node
,
input_storage
,
output_storage
):
def
perform
(
self
,
node
,
input_storage
,
output_storage
):
"""
"""The math, where x is an input vector, and t is a target index:
The math, where x is an input vector, and t is a target index:
softmax(x)[i] = exp(x[i]) / sum_j(exp(x[j]))
softmax(x)[i] = exp(x[i]) / sum_j(exp(x[j]))
nll(x,t) = -log(softmax(x)[t])
nll(x,t) = -log(softmax(x)[t])
We compute this by subtracting off the max of x. This avoids numerical instability.
We compute this by subtracting off the max of x. This avoids
numerical instability.
m = max_j x[j]
m = max_j x[j]
softmax(x)[i] = exp(x[i] -m) / sum_j(exp(x[j] - m))
softmax(x)[i] = exp(x[i] -m) / sum_j(exp(x[j] - m))
...
@@ -627,20 +673,22 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op):
...
@@ -627,20 +673,22 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op):
if
y_idx
.
shape
[
0
]
!=
x
.
shape
[
0
]:
if
y_idx
.
shape
[
0
]
!=
x
.
shape
[
0
]:
raise
ValueError
(
'y_idx must have same number of rows as x'
)
raise
ValueError
(
'y_idx must have same number of rows as x'
)
sm
=
numpy
.
zeros_like
(
x
)
# softmax
sm
=
numpy
.
zeros_like
(
x
)
# softmax
nll
=
numpy
.
zeros
(
x
.
shape
[
0
],
dtype
=
node
.
outputs
[
0
]
.
type
.
dtype
)
#nll(y | softmax(x))
nll
=
numpy
.
zeros
(
x
.
shape
[
0
],
dtype
=
node
.
outputs
[
0
]
.
type
.
dtype
)
# nll(y | softmax(x))
am
=
numpy
.
zeros_like
(
y_idx
)
am
=
numpy
.
zeros_like
(
y_idx
)
for
i
in
xrange
(
sm
.
shape
[
0
]):
for
i
in
xrange
(
sm
.
shape
[
0
]):
#add the bias vector to the i'th row of x
#add the bias vector to the i'th row of x
row
=
x
[
i
]
+
b
row
=
x
[
i
]
+
b
#get the maximum value of i'th row for numerically safe softmax / nll
#get the maximum value of i'th row for numerically safe
#softmax / nll
am
[
i
]
=
numpy
.
argmax
(
row
)
am
[
i
]
=
numpy
.
argmax
(
row
)
m
=
row
[
am
[
i
]]
m
=
row
[
am
[
i
]]
#compute the unnormalized softmax, and normalization constant
#compute the unnormalized softmax, and normalization constant
sm
[
i
]
=
numpy
.
exp
(
row
-
m
)
sm
[
i
]
=
numpy
.
exp
(
row
-
m
)
sum_j
=
numpy
.
sum
(
sm
[
i
])
# sum_j(exp(x[j] - m))
sum_j
=
numpy
.
sum
(
sm
[
i
])
# sum_j(exp(x[j] - m))
#normalized our softmax
#normalized our softmax
sm
[
i
]
*=
1.0
/
sum_j
sm
[
i
]
*=
1.0
/
sum_j
...
@@ -675,7 +723,7 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op):
...
@@ -675,7 +723,7 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op):
nll
,
sm
=
crossentropy_softmax_1hot_with_bias
(
x
,
b
,
y_idx
)
nll
,
sm
=
crossentropy_softmax_1hot_with_bias
(
x
,
b
,
y_idx
)
#dx = CrossentropySoftmax1HotWithBiasDx()(g_nll, sm, y_idx)
#dx = CrossentropySoftmax1HotWithBiasDx()(g_nll, sm, y_idx)
dx
=
crossentropy_softmax_1hot_with_bias_dx
(
g_nll
,
sm
,
y_idx
)
dx
=
crossentropy_softmax_1hot_with_bias_dx
(
g_nll
,
sm
,
y_idx
)
db
=
tensor
.
sum
(
dx
,
axis
=
[
0
])
db
=
tensor
.
sum
(
dx
,
axis
=
[
0
])
return
dx
,
db
,
None
return
dx
,
db
,
None
def
c_headers
(
self
):
def
c_headers
(
self
):
...
@@ -706,13 +754,16 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op):
...
@@ -706,13 +754,16 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op):
&& (
%(y_idx)
s->descr->type_num != PyArray_INT16)
&& (
%(y_idx)
s->descr->type_num != PyArray_INT16)
&& (
%(y_idx)
s->descr->type_num != PyArray_INT8))
&& (
%(y_idx)
s->descr->type_num != PyArray_INT8))
{
{
PyErr_SetString(PyExc_TypeError, "y_idx not int8, int16, int32, or int64");
PyErr_SetString(PyExc_TypeError,
"y_idx not int8, int16, int32, or int64");
%(fail)
s;
%(fail)
s;
}
}
if (
%(x)
s->dimensions[0] !=
%(y_idx)
s->dimensions[0])
if (
%(x)
s->dimensions[0] !=
%(y_idx)
s->dimensions[0])
{
{
PyErr_Format(PyExc_ValueError, "number of rows in x (
%%
ld) does not match length of y (
%%
ld)",
PyErr_Format(PyExc_ValueError,
(long int)
%(x)
s->dimensions[0], (long int)
%(y_idx)
s->dimensions[0]);
"number of rows in x (
%%
ld) does not match length of y (
%%
ld)",
(long int)
%(x)
s->dimensions[0],
(long int)
%(y_idx)
s->dimensions[0]);
%(fail)
s;
%(fail)
s;
}
}
...
@@ -720,10 +771,12 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op):
...
@@ -720,10 +771,12 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op):
|| (
%(nll)
s->dimensions[0] !=
%(y_idx)
s->dimensions[0]))
|| (
%(nll)
s->dimensions[0] !=
%(y_idx)
s->dimensions[0]))
{
{
if (NULL !=
%(nll)
s) Py_XDECREF(
%(nll)
s);
if (NULL !=
%(nll)
s) Py_XDECREF(
%(nll)
s);
%(nll)
s = (PyArrayObject*)PyArray_SimpleNew(1, PyArray_DIMS(
%(y_idx)
s), type_num_
%(x)
s);
%(nll)
s = (PyArrayObject*)PyArray_SimpleNew(1,
PyArray_DIMS(
%(y_idx)
s), type_num_
%(x)
s);
if(!
%(nll)
s)
if(!
%(nll)
s)
{
{
PyErr_SetString(PyExc_MemoryError, "failed to alloc nll output");
PyErr_SetString(PyExc_MemoryError,
"failed to alloc nll output");
%(fail)
s;
%(fail)
s;
}
}
}
}
...
@@ -731,18 +784,20 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op):
...
@@ -731,18 +784,20 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op):
|| (
%(am)
s->dimensions[0] !=
%(y_idx)
s->dimensions[0]))
|| (
%(am)
s->dimensions[0] !=
%(y_idx)
s->dimensions[0]))
{
{
Py_XDECREF(
%(am)
s);
Py_XDECREF(
%(am)
s);
%(am)
s = (PyArrayObject*) PyArray_SimpleNew(1, PyArray_DIMS(
%(y_idx)
s), type_num_
%(y_idx)
s);
%(am)
s = (PyArrayObject*) PyArray_SimpleNew(1,
PyArray_DIMS(
%(y_idx)
s), type_num_
%(y_idx)
s);
if(!
%(am)
s)
if(!
%(am)
s)
{
{
PyErr_SetString(PyExc_MemoryError, "failed to alloc am output");
PyErr_SetString(PyExc_MemoryError,
"failed to alloc am output");
%(fail)
s;
%(fail)
s;
}
}
}
}
"""
,
"""
,
begin_row_loop
,
begin_row_loop
,
"""
"""
const
%(y_idx_type)
s y_i = ((
%(y_idx_type)
s*)(
%(y_idx)
s->data +
%(y_idx)
s->strides[0] * i))[0];
const
%(y_idx_type)
s y_i = ((
%(y_idx_type)
s*)(
%(y_idx)
s->data +
%(y_idx)
s->strides[0] * i))[0];
dtype_
%(nll)
s* __restrict__ nll_i = (dtype_
%(nll)
s*)(
%(nll)
s->data +
%(nll)
s->strides[0] * i);
dtype_
%(nll)
s* __restrict__ nll_i = (dtype_
%(nll)
s*)(
%(nll)
s->data +
%(nll)
s->strides[0] * i);
%(am_type)
s* __restrict__ am_i = (
%(am_type)
s*) (
%(am)
s->data +
%(am)
s->strides[0] * i);
%(am_type)
s* __restrict__ am_i = (
%(am_type)
s*) (
%(am)
s->data +
%(am)
s->strides[0] * i);
"""
,
"""
,
inside_row_loop
,
inside_row_loop
,
...
@@ -760,9 +815,9 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op):
...
@@ -760,9 +815,9 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op):
"""
,
"""
,
end_row_loop
)
end_row_loop
)
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
return
(
5
,)
+
SoftmaxWithBias
.
c_code_cache_version
()
return
(
5
,)
+
SoftmaxWithBias
.
c_code_cache_version
()
def
c_code
(
self
,
node
,
name
,
inp
,
out
,
sub
):
def
c_code
(
self
,
node
,
name
,
inp
,
out
,
sub
):
x
,
b
,
y_idx
=
inp
x
,
b
,
y_idx
=
inp
nll
,
sm
,
am
=
out
nll
,
sm
,
am
=
out
...
@@ -771,30 +826,37 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op):
...
@@ -771,30 +826,37 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op):
code_template
=
''
.
join
(
self
.
c_code_template
())
code_template
=
''
.
join
(
self
.
c_code_template
())
return
code_template
%
dict
(
locals
(),
**
sub
)
return
code_template
%
dict
(
locals
(),
**
sub
)
class
CrossentropySoftmax1HotWithBiasDx
(
gof
.
Op
):
class
CrossentropySoftmax1HotWithBiasDx
(
gof
.
Op
):
nin
=
3
nin
=
3
nout
=
1
nout
=
1
"""Gradient wrt x of the CrossentropySoftmax
1Hot
Op"""
"""Gradient wrt x of the CrossentropySoftmax
Argmax1HotWithBias
Op"""
def
__init__
(
self
,
**
kwargs
):
def
__init__
(
self
,
**
kwargs
):
gof
.
Op
.
__init__
(
self
,
**
kwargs
)
gof
.
Op
.
__init__
(
self
,
**
kwargs
)
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
return
type
(
self
)
==
type
(
other
)
return
type
(
self
)
==
type
(
other
)
def
__hash__
(
self
):
def
__hash__
(
self
):
return
tensor
.
hashtype
(
self
)
return
tensor
.
hashtype
(
self
)
def
__str__
(
self
):
def
__str__
(
self
):
return
self
.
__class__
.
__name__
return
self
.
__class__
.
__name__
def
make_node
(
self
,
dy
,
sm
,
y_idx
,
**
kwargs
):
def
make_node
(
self
,
dy
,
sm
,
y_idx
,
**
kwargs
):
dy
=
tensor
.
as_tensor_variable
(
dy
)
dy
=
tensor
.
as_tensor_variable
(
dy
)
sm
=
tensor
.
as_tensor_variable
(
sm
)
sm
=
tensor
.
as_tensor_variable
(
sm
)
y_idx
=
tensor
.
as_tensor_variable
(
y_idx
)
y_idx
=
tensor
.
as_tensor_variable
(
y_idx
)
return
Apply
(
self
,
[
dy
,
sm
,
y_idx
],[
sm
.
type
.
make_variable
()])
return
Apply
(
self
,
[
dy
,
sm
,
y_idx
],
[
sm
.
type
.
make_variable
()])
def
perform
(
self
,
node
,
input_storage
,
output_storage
):
def
perform
(
self
,
node
,
input_storage
,
output_storage
):
dy
,
sm
,
y_idx
=
input_storage
dy
,
sm
,
y_idx
=
input_storage
dx
=
numpy
.
zeros_like
(
sm
)
dx
=
numpy
.
zeros_like
(
sm
)
for
i
in
xrange
(
sm
.
shape
[
0
]):
for
i
in
xrange
(
sm
.
shape
[
0
]):
dx
[
i
]
=
dy
[
i
]
*
sm
[
i
]
#
vector scale
dx
[
i
]
=
dy
[
i
]
*
sm
[
i
]
#
vector scale
dx
[
i
,
y_idx
[
i
]]
-=
dy
[
i
]
#
scalar decrement
dx
[
i
,
y_idx
[
i
]]
-=
dy
[
i
]
#
scalar decrement
output_storage
[
0
][
0
]
=
dx
output_storage
[
0
][
0
]
=
dx
def
grad
(
self
,
inp
,
grads
):
def
grad
(
self
,
inp
,
grads
):
dy
,
sm
,
y_idx
=
inp
dy
,
sm
,
y_idx
=
inp
g_dx
,
=
grads
g_dx
,
=
grads
...
@@ -810,22 +872,28 @@ class CrossentropySoftmax1HotWithBiasDx (gof.Op):
...
@@ -810,22 +872,28 @@ class CrossentropySoftmax1HotWithBiasDx (gof.Op):
g_sm
=
dy
.
dimshuffle
(
0
,
'x'
)
*
g_dx
g_sm
=
dy
.
dimshuffle
(
0
,
'x'
)
*
g_dx
g_y_idx
=
None
g_y_idx
=
None
return
[
g_dy
,
g_sm
,
g_y_idx
]
return
[
g_dy
,
g_sm
,
g_y_idx
]
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
return
(
2
,)
return
(
2
,)
def
c_code
(
self
,
node
,
name
,
inp
,
out
,
sub
):
def
c_code
(
self
,
node
,
name
,
inp
,
out
,
sub
):
dnll
,
sm
,
y_idx
=
inp
dnll
,
sm
,
y_idx
=
inp
dx
,
=
out
dx
,
=
out
y_idx_type
=
node
.
inputs
[
2
]
.
type
.
dtype_specs
()[
1
]
y_idx_type
=
node
.
inputs
[
2
]
.
type
.
dtype_specs
()[
1
]
return
"""
return
"""
if ((
%(dnll)
s->descr->type_num != PyArray_DOUBLE) && (
%(dnll)
s->descr->type_num != PyArray_FLOAT))
if ((
%(dnll)
s->descr->type_num != PyArray_DOUBLE) &&
(
%(dnll)
s->descr->type_num != PyArray_FLOAT))
{
{
PyErr_SetString(PyExc_TypeError, "dnll type should be float32 or float64");
PyErr_SetString(PyExc_TypeError,
"dnll type should be float32 or float64");
%(fail)
s;
%(fail)
s;
}
}
if ((
%(sm)
s->descr->type_num != PyArray_DOUBLE) && (
%(sm)
s->descr->type_num != PyArray_FLOAT))
if ((
%(sm)
s->descr->type_num != PyArray_DOUBLE) &&
(
%(sm)
s->descr->type_num != PyArray_FLOAT))
{
{
PyErr_SetString(PyExc_TypeError, "sm type should be float32 or float64");
PyErr_SetString(PyExc_TypeError,
"sm type should be float32 or float64");
%(fail)
s;
%(fail)
s;
}
}
if ((
%(y_idx)
s->descr->type_num != PyArray_INT64)
if ((
%(y_idx)
s->descr->type_num != PyArray_INT64)
...
@@ -833,7 +901,8 @@ class CrossentropySoftmax1HotWithBiasDx (gof.Op):
...
@@ -833,7 +901,8 @@ class CrossentropySoftmax1HotWithBiasDx (gof.Op):
&& (
%(y_idx)
s->descr->type_num != PyArray_INT16)
&& (
%(y_idx)
s->descr->type_num != PyArray_INT16)
&& (
%(y_idx)
s->descr->type_num != PyArray_INT8))
&& (
%(y_idx)
s->descr->type_num != PyArray_INT8))
{
{
PyErr_SetString(PyExc_TypeError, "y_idx not int8, int16, int32, or int64");
PyErr_SetString(PyExc_TypeError,
"y_idx not int8, int16, int32, or int64");
%(fail)
s;
%(fail)
s;
}
}
if ((
%(dnll)
s->nd != 1)
if ((
%(dnll)
s->nd != 1)
...
@@ -845,14 +914,18 @@ class CrossentropySoftmax1HotWithBiasDx (gof.Op):
...
@@ -845,14 +914,18 @@ class CrossentropySoftmax1HotWithBiasDx (gof.Op):
}
}
if (
%(dnll)
s->dimensions[0] !=
%(sm)
s->dimensions[0])
if (
%(dnll)
s->dimensions[0] !=
%(sm)
s->dimensions[0])
{
{
PyErr_Format(PyExc_ValueError, "dnll.shape[0] (
%%
ld) != sm.shape[0] (
%%
ld)",
PyErr_Format(PyExc_ValueError,
(long int)
%(dnll)
s->dimensions[0], (long int)
%(sm)
s->dimensions[0]);
"dnll.shape[0] (
%%
ld) != sm.shape[0] (
%%
ld)",
(long int)
%(dnll)
s->dimensions[0],
(long int)
%(sm)
s->dimensions[0]);
%(fail)
s;
%(fail)
s;
}
}
if (
%(dnll)
s->dimensions[0] !=
%(y_idx)
s->dimensions[0])
if (
%(dnll)
s->dimensions[0] !=
%(y_idx)
s->dimensions[0])
{
{
PyErr_Format(PyExc_ValueError, "dnll.shape[0] (
%%
ld) != y_idx.shape[0] (
%%
ld)",
PyErr_Format(PyExc_ValueError,
(long int)
%(dnll)
s->dimensions[0], (long int)
%(y_idx)
s->dimensions[0]);
"dnll.shape[0] (
%%
ld) != y_idx.shape[0] (
%%
ld)",
(long int)
%(dnll)
s->dimensions[0],
(long int)
%(y_idx)
s->dimensions[0]);
%(fail)
s;
%(fail)
s;
}
}
if ((NULL ==
%(dx)
s)
if ((NULL ==
%(dx)
s)
...
@@ -860,9 +933,12 @@ class CrossentropySoftmax1HotWithBiasDx (gof.Op):
...
@@ -860,9 +933,12 @@ class CrossentropySoftmax1HotWithBiasDx (gof.Op):
|| (
%(dx)
s->dimensions[1] !=
%(sm)
s->dimensions[1]))
|| (
%(dx)
s->dimensions[1] !=
%(sm)
s->dimensions[1]))
{
{
if (NULL !=
%(dx)
s) Py_XDECREF(
%(dx)
s);
if (NULL !=
%(dx)
s) Py_XDECREF(
%(dx)
s);
%(dx)
s = (PyArrayObject*)PyArray_SimpleNew(2, PyArray_DIMS(
%(sm)
s), type_num_
%(sm)
s);
%(dx)
s = (PyArrayObject*) PyArray_SimpleNew(2,
PyArray_DIMS(
%(sm)
s),
type_num_
%(sm)
s);
if(!
%(dx)
s) {
if(!
%(dx)
s) {
PyErr_SetString(PyExc_MemoryError, "failed to alloc dx output");
PyErr_SetString(PyExc_MemoryError,
"failed to alloc dx output");
%(fail)
s
%(fail)
s
}
}
}
}
...
@@ -871,12 +947,12 @@ class CrossentropySoftmax1HotWithBiasDx (gof.Op):
...
@@ -871,12 +947,12 @@ class CrossentropySoftmax1HotWithBiasDx (gof.Op):
{
{
const dtype_
%(dnll)
s dnll_i = ((dtype_
%(dnll)
s*)(
%(dnll)
s->data +
%(dnll)
s->strides[0] * i))[0];
const dtype_
%(dnll)
s dnll_i = ((dtype_
%(dnll)
s*)(
%(dnll)
s->data +
%(dnll)
s->strides[0] * i))[0];
const
%(y_idx_type)
s y_i = ((
%(y_idx_type)
s*)(
%(y_idx)
s->data +
%(y_idx)
s->strides[0] * i))[0];
const
%(y_idx_type)
s y_i = ((
%(y_idx_type)
s*)(
%(y_idx)
s->data +
%(y_idx)
s->strides[0] * i))[0];
const dtype_
%(sm)
s* __restrict__ sm_i = (dtype_
%(sm)
s*)(
%(sm)
s->data +
%(sm)
s->strides[0] * i);
const dtype_
%(sm)
s* __restrict__ sm_i = (dtype_
%(sm)
s*)(
%(sm)
s->data +
%(sm)
s->strides[0] * i);
npy_intp Ssm =
%(sm)
s->strides[1]/sizeof(dtype_
%(sm)
s);
npy_intp Ssm =
%(sm)
s->strides[1]/sizeof(dtype_
%(sm)
s);
dtype_
%(dx)
s* __restrict__ dx_i = (dtype_
%(dx)
s*)(
%(dx)
s->data +
%(dx)
s->strides[0] * i);
dtype_
%(dx)
s* __restrict__ dx_i = (dtype_
%(dx)
s*)(
%(dx)
s->data +
%(dx)
s->strides[0] * i);
npy_intp Sdx =
%(dx)
s->strides[1]/sizeof(dtype_
%(dx)
s);
npy_intp Sdx =
%(dx)
s->strides[1]/sizeof(dtype_
%(dx)
s);
for (size_t j = 0; j <
%(dx)
s->dimensions[1]; ++j)
for (size_t j = 0; j <
%(dx)
s->dimensions[1]; ++j)
...
@@ -898,48 +974,68 @@ crossentropy_softmax_argmax_1hot_with_bias = \
...
@@ -898,48 +974,68 @@ crossentropy_softmax_argmax_1hot_with_bias = \
crossentropy_softmax_1hot_with_bias_dx
=
\
crossentropy_softmax_1hot_with_bias_dx
=
\
CrossentropySoftmax1HotWithBiasDx
()
CrossentropySoftmax1HotWithBiasDx
()
def
crossentropy_softmax_1hot_with_bias
(
x
,
b
,
y_idx
,
**
kwargs
):
def
crossentropy_softmax_1hot_with_bias
(
x
,
b
,
y_idx
,
**
kwargs
):
return
crossentropy_softmax_argmax_1hot_with_bias
(
x
,
b
,
y_idx
,
**
kwargs
)[
0
:
2
]
return
crossentropy_softmax_argmax_1hot_with_bias
(
x
,
b
,
y_idx
,
**
kwargs
)[
0
:
2
]
def
crossentropy_softmax_1hot
(
x
,
y_idx
,
**
kwargs
):
def
crossentropy_softmax_1hot
(
x
,
y_idx
,
**
kwargs
):
b
=
tensor
.
zeros_like
(
x
[
0
,:])
b
=
tensor
.
zeros_like
(
x
[
0
,
:])
return
crossentropy_softmax_1hot_with_bias
(
x
,
b
,
y_idx
,
**
kwargs
)
return
crossentropy_softmax_1hot_with_bias
(
x
,
b
,
y_idx
,
**
kwargs
)
def
crossentropy_softmax_max_and_argmax_1hot_with_bias
(
x
,
b
,
y_idx
,
**
kwargs
):
def
crossentropy_softmax_max_and_argmax_1hot_with_bias
(
x
,
b
,
y_idx
,
**
kwargs
):
"""
"""
@return: The cross-entropy, the softmax output, the max probability, and the argmax index
@return: The cross-entropy, the softmax output, the max probability,
@todo: Since we are recomputing the argmax, we might as well assert that it is correct.
and the argmax index
@todo: Since we are recomputing the argmax,
we might as well assert that it is correct.
@todo: Make this entire function is
@todo: Make this entire function is
unnecessary? e.g. CrossentropySoftmaxArgmax1HotWithBias should return
unnecessary? e.g. CrossentropySoftmaxArgmax1HotWithBias should return
the appropriate information (i.e. the max probability)?
the appropriate information (i.e. the max probability)?
"""
"""
(
xent
,
softmax
)
=
crossentropy_softmax_1hot_with_bias
(
x
,
b
,
y_idx
,
**
kwargs
)
(
xent
,
softmax
)
=
crossentropy_softmax_1hot_with_bias
(
x
,
b
,
y_idx
,
**
kwargs
)
(
max_pr
,
argmax
)
=
tensor
.
max_and_argmax
(
softmax
,
axis
=-
1
)
(
max_pr
,
argmax
)
=
tensor
.
max_and_argmax
(
softmax
,
axis
=-
1
)
return
(
xent
,
softmax
,
max_pr
,
argmax
)
return
(
xent
,
softmax
,
max_pr
,
argmax
)
def
crossentropy_softmax_max_and_argmax_1hot
(
x
,
y_idx
,
**
kwargs
):
def
crossentropy_softmax_max_and_argmax_1hot
(
x
,
y_idx
,
**
kwargs
):
b
=
tensor
.
zeros_like
(
x
[
0
,:])
b
=
tensor
.
zeros_like
(
x
[
0
,
:])
return
crossentropy_softmax_max_and_argmax_1hot_with_bias
(
x
,
b
,
y_idx
,
**
kwargs
)
return
crossentropy_softmax_max_and_argmax_1hot_with_bias
(
x
,
b
,
y_idx
,
**
kwargs
)
class
CrossentropyCategorical1HotGrad
(
gof
.
Op
):
class
CrossentropyCategorical1HotGrad
(
gof
.
Op
):
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
return
type
(
self
)
==
type
(
other
)
return
type
(
self
)
==
type
(
other
)
def
__hash__
(
self
):
def
__hash__
(
self
):
return
tensor
.
hashtype
(
self
)
return
tensor
.
hashtype
(
self
)
def
__str__
(
self
):
def
__str__
(
self
):
return
self
.
__class__
.
__name__
return
self
.
__class__
.
__name__
def
make_node
(
self
,
g_y
,
coding_dist
,
true_one_of_n
):
def
make_node
(
self
,
g_y
,
coding_dist
,
true_one_of_n
):
return
Apply
(
self
,
[
g_y
,
coding_dist
,
true_one_of_n
],
[
coding_dist
.
type
()])
return
Apply
(
self
,
[
g_y
,
coding_dist
,
true_one_of_n
],
[
coding_dist
.
type
()])
def
perform
(
self
,
node
,
inp
,
out
):
def
perform
(
self
,
node
,
inp
,
out
):
g_y
,
coding_dist
,
true_one_of_n
=
inp
g_y
,
coding_dist
,
true_one_of_n
=
inp
g_coding_strg
,
=
out
g_coding_strg
,
=
out
g_coding
=
numpy
.
zeros_like
(
coding_dist
)
g_coding
=
numpy
.
zeros_like
(
coding_dist
)
for
i
in
xrange
(
len
(
g_y
)):
for
i
in
xrange
(
len
(
g_y
)):
g_coding
[
i
,
true_one_of_n
[
i
]]
=
-
g_y
[
i
]
/
coding_dist
[
i
,
true_one_of_n
[
i
]]
g_coding
[
i
,
true_one_of_n
[
i
]]
=
-
g_y
[
i
]
/
coding_dist
[
i
,
true_one_of_n
[
i
]]
g_coding_strg
[
0
]
=
g_coding
g_coding_strg
[
0
]
=
g_coding
crossentropy_categorical_1hot_grad
=
CrossentropyCategorical1HotGrad
()
crossentropy_categorical_1hot_grad
=
CrossentropyCategorical1HotGrad
()
class
CrossentropyCategorical1Hot
(
gof
.
Op
):
class
CrossentropyCategorical1Hot
(
gof
.
Op
):
"""Compute the cross entropy between a coding distribution and
"""Compute the cross entropy between a coding distribution and
...
@@ -950,18 +1046,21 @@ class CrossentropyCategorical1Hot(gof.Op):
...
@@ -950,18 +1046,21 @@ class CrossentropyCategorical1Hot(gof.Op):
y[i] = -
\
log(coding_dist[i, one_of_n[i])
y[i] = -
\
log(coding_dist[i, one_of_n[i])
:note:
:note:
In the case that the coding distribution is the output of a
In the case that the coding distribution is the output of a softmax, an application of this
softmax, an application of this Op will probably be optimized
Op will probably be optimized
away in favour of one with a C implementation.
away in favour of one with a C implementation.
"""
"""
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
return
type
(
self
)
==
type
(
other
)
return
type
(
self
)
==
type
(
other
)
def
__hash__
(
self
):
def
__hash__
(
self
):
return
tensor
.
hashtype
(
self
)
return
tensor
.
hashtype
(
self
)
def
__str__
(
self
):
def
__str__
(
self
):
return
self
.
__class__
.
__name__
return
self
.
__class__
.
__name__
def
make_node
(
self
,
coding_dist
,
true_one_of_n
):
def
make_node
(
self
,
coding_dist
,
true_one_of_n
):
"""
"""
:type coding_dist: dense matrix
:type coding_dist: dense matrix
...
@@ -975,17 +1074,19 @@ class CrossentropyCategorical1Hot(gof.Op):
...
@@ -975,17 +1074,19 @@ class CrossentropyCategorical1Hot(gof.Op):
if
_coding_dist
.
type
.
ndim
!=
2
:
if
_coding_dist
.
type
.
ndim
!=
2
:
raise
TypeError
(
'matrix required for argument: coding_dist'
)
raise
TypeError
(
'matrix required for argument: coding_dist'
)
if
_true_one_of_n
.
type
not
in
(
tensor
.
lvector
,
tensor
.
ivector
):
if
_true_one_of_n
.
type
not
in
(
tensor
.
lvector
,
tensor
.
ivector
):
raise
TypeError
(
'integer vector required for argument: true_one_of_n'
raise
TypeError
(
'(got type:
%
s instead of:
%
s)'
%
(
_true_one_of_n
.
type
,
'integer vector required for argument: true_one_of_n'
tensor
.
lvector
))
'(got type:
%
s instead of:
%
s)'
%
(
_true_one_of_n
.
type
,
tensor
.
lvector
))
return
Apply
(
self
,
[
_coding_dist
,
_true_one_of_n
],
return
Apply
(
self
,
[
_coding_dist
,
_true_one_of_n
],
[
tensor
.
Tensor
(
dtype
=
_coding_dist
.
dtype
,
broadcastable
=
[
False
])()])
[
tensor
.
Tensor
(
dtype
=
_coding_dist
.
dtype
,
broadcastable
=
[
False
])()])
def
perform
(
self
,
node
,
inp
,
out
):
def
perform
(
self
,
node
,
inp
,
out
):
coding
,
one_of_n
=
inp
coding
,
one_of_n
=
inp
y_out
,
=
out
y_out
,
=
out
y
=
numpy
.
zeros_like
(
coding
[:,
0
])
y
=
numpy
.
zeros_like
(
coding
[:,
0
])
for
i
in
xrange
(
len
(
y
)):
for
i
in
xrange
(
len
(
y
)):
y
[
i
]
=
-
numpy
.
log
(
coding
[
i
,
one_of_n
[
i
]])
y
[
i
]
=
-
numpy
.
log
(
coding
[
i
,
one_of_n
[
i
]])
y_out
[
0
]
=
y
y_out
[
0
]
=
y
...
@@ -993,18 +1094,21 @@ class CrossentropyCategorical1Hot(gof.Op):
...
@@ -993,18 +1094,21 @@ class CrossentropyCategorical1Hot(gof.Op):
def
grad
(
self
,
inp
,
grads
):
def
grad
(
self
,
inp
,
grads
):
coding
,
one_of_n
=
inp
coding
,
one_of_n
=
inp
g_y
,
=
grads
g_y
,
=
grads
return
[
crossentropy_categorical_1hot_grad
(
g_y
,
coding
,
one_of_n
),
None
]
return
[
crossentropy_categorical_1hot_grad
(
g_y
,
coding
,
one_of_n
),
None
]
crossentropy_categorical_1hot
=
CrossentropyCategorical1Hot
()
crossentropy_categorical_1hot
=
CrossentropyCategorical1Hot
()
@opt.register_stabilize
@opt.register_stabilize
@opt.register_specialize
@opt.register_specialize
@gof.optimizer
@gof.optimizer
def
crossentropy_to_crossentropy_with_softmax_with_bias
(
fgraph
):
def
crossentropy_to_crossentropy_with_softmax_with_bias
(
fgraph
):
"""
"""This is a stabilization optimization
This is a stabilization optimization
:note: not a local optimization because we are replacing outputs
from several nodes at once
..note: not a local optimization because we are replacing outputs from several nodes at once
"""
"""
def
search_make_one_sub
():
def
search_make_one_sub
():
...
@@ -1016,7 +1120,7 @@ def crossentropy_to_crossentropy_with_softmax_with_bias(fgraph):
...
@@ -1016,7 +1120,7 @@ def crossentropy_to_crossentropy_with_softmax_with_bias(fgraph):
x
,
b
=
sm
.
owner
.
inputs
x
,
b
=
sm
.
owner
.
inputs
new_nll
,
new_sm
,
new_am
=
crossentropy_softmax_argmax_1hot_with_bias
(
x
,
b
,
new_nll
,
new_sm
,
new_am
=
crossentropy_softmax_argmax_1hot_with_bias
(
x
,
b
,
one_of_n
)
one_of_n
)
fgraph
.
replace_all_validate
([(
nll
,
new_nll
),(
sm
,
new_sm
)],
fgraph
.
replace_all_validate
([(
nll
,
new_nll
),
(
sm
,
new_sm
)],
reason
=
"crossentropy_to_crossentropy_with_softmax"
)
reason
=
"crossentropy_to_crossentropy_with_softmax"
)
return
True
return
True
...
@@ -1026,16 +1130,20 @@ def crossentropy_to_crossentropy_with_softmax_with_bias(fgraph):
...
@@ -1026,16 +1130,20 @@ def crossentropy_to_crossentropy_with_softmax_with_bias(fgraph):
pass
pass
return
return
@gof.optimizer
@gof.optimizer
def
crossentropy_to_crossentropy_with_softmax
(
fgraph
):
def
crossentropy_to_crossentropy_with_softmax
(
fgraph
):
"""
"""
This is a stabilization optimization that is more general then
This is a stabilization optimization that is more general then
crossentropy_to_crossentropy_with_softmax_with_bias
crossentropy_to_crossentropy_with_softmax_with_bias
It must be executed after local_softmax_with_bias optimization in specialize
It must be executed after local_softmax_with_bias optimization in
specialize
: todo: This is a stabilization optimization! How to make this more cleanly?
:todo: This is a stabilization optimization! How to make this more cleanly?
:note: not a local optimization because we are replacing outputs
from several nodes at once
..note: not a local optimization because we are replacing outputs from several nodes at once
"""
"""
def
search_make_one_sub
():
def
search_make_one_sub
():
...
@@ -1047,14 +1155,14 @@ def crossentropy_to_crossentropy_with_softmax(fgraph):
...
@@ -1047,14 +1155,14 @@ def crossentropy_to_crossentropy_with_softmax(fgraph):
x
,
=
sm
.
owner
.
inputs
x
,
=
sm
.
owner
.
inputs
new_nll
,
new_sm
,
new_am
=
crossentropy_softmax_argmax_1hot_with_bias
(
x
,
new_nll
,
new_sm
,
new_am
=
crossentropy_softmax_argmax_1hot_with_bias
(
x
,
tensor
.
zeros_like
(
x
[
0
]),
one_of_n
)
tensor
.
zeros_like
(
x
[
0
]),
one_of_n
)
fgraph
.
replace_all_validate
([(
nll
,
new_nll
),(
sm
,
new_sm
)],
fgraph
.
replace_all_validate
([(
nll
,
new_nll
),
(
sm
,
new_sm
)],
reason
=
"crossentropy_to_crossentropy_with_softmax"
)
reason
=
"crossentropy_to_crossentropy_with_softmax"
)
return
True
return
True
if
sm
.
owner
and
sm
.
owner
.
op
==
softmax_with_bias
:
if
sm
.
owner
and
sm
.
owner
.
op
==
softmax_with_bias
:
x
,
b
=
sm
.
owner
.
inputs
x
,
b
=
sm
.
owner
.
inputs
new_nll
,
new_sm
,
new_am
=
crossentropy_softmax_argmax_1hot_with_bias
(
x
,
b
,
new_nll
,
new_sm
,
new_am
=
crossentropy_softmax_argmax_1hot_with_bias
(
x
,
b
,
one_of_n
)
one_of_n
)
fgraph
.
replace_all_validate
([(
nll
,
new_nll
),(
sm
,
new_sm
)],
fgraph
.
replace_all_validate
([(
nll
,
new_nll
),
(
sm
,
new_sm
)],
reason
=
"crossentropy_to_crossentropy_with_softmax"
)
reason
=
"crossentropy_to_crossentropy_with_softmax"
)
return
True
return
True
...
@@ -1064,24 +1172,29 @@ def crossentropy_to_crossentropy_with_softmax(fgraph):
...
@@ -1064,24 +1172,29 @@ def crossentropy_to_crossentropy_with_softmax(fgraph):
pass
pass
return
return
optdb
.
register
(
'crossentropy_to_crossentropy_with_softmax'
,
crossentropy_to_crossentropy_with_softmax
,
2.01
,
optdb
.
register
(
'crossentropy_to_crossentropy_with_softmax'
,
'fast_run'
,
'xent'
)
crossentropy_to_crossentropy_with_softmax
,
2.01
,
'fast_run'
,
'xent'
)
@gof.local_optimizer
([
softmax_grad
])
@gof.local_optimizer
([
softmax_grad
])
def
local_crossentropy_to_crossentropy_with_softmax_grad
(
node
):
def
local_crossentropy_to_crossentropy_with_softmax_grad
(
node
):
if
node
.
op
==
softmax_grad
:
if
node
.
op
==
softmax_grad
:
g_coding_dist
,
coding_dist
=
node
.
inputs
g_coding_dist
,
coding_dist
=
node
.
inputs
if
g_coding_dist
.
owner
and
g_coding_dist
.
owner
.
op
==
crossentropy_categorical_1hot_grad
:
if
(
g_coding_dist
.
owner
and
g_coding_dist
.
owner
.
op
==
crossentropy_categorical_1hot_grad
):
g_nll
,
coding_dist
,
true_one_of_n
=
g_coding_dist
.
owner
.
inputs
g_nll
,
coding_dist
,
true_one_of_n
=
g_coding_dist
.
owner
.
inputs
dx
=
crossentropy_softmax_1hot_with_bias_dx
(
g_nll
,
coding_dist
,
true_one_of_n
)
dx
=
crossentropy_softmax_1hot_with_bias_dx
(
g_nll
,
coding_dist
,
true_one_of_n
)
return
[
dx
]
return
[
dx
]
opt
.
register_specialize
(
local_crossentropy_to_crossentropy_with_softmax_grad
)
opt
.
register_specialize
(
local_crossentropy_to_crossentropy_with_softmax_grad
)
@opt.register_specialize
@opt.register_specialize
@gof.local_optimizer
([
tensor
.
_max_and_argmax
])
@gof.local_optimizer
([
tensor
.
_max_and_argmax
])
def
local_argmax_pushdown
(
node
):
def
local_argmax_pushdown
(
node
):
if
node
.
op
==
tensor
.
_max_and_argmax
and
node
.
inputs
[
0
]
.
owner
and
\
if
node
.
op
==
tensor
.
_max_and_argmax
and
node
.
inputs
[
0
]
.
owner
and
\
len
(
node
.
outputs
[
0
]
.
clients
)
>
0
and
node
.
inputs
[
0
]
.
owner
.
op
in
\
len
(
node
.
outputs
[
0
]
.
clients
)
>
0
and
node
.
inputs
[
0
]
.
owner
.
op
in
\
(
softmax
,
softplus
,
tensor
.
exp
,
tensor
.
log
,
tensor
.
tanh
,
sigmoid
,
(
softmax
,
softplus
,
tensor
.
exp
,
tensor
.
log
,
tensor
.
tanh
,
sigmoid
,
softmax_with_bias
):
softmax_with_bias
):
if
theano
.
config
.
warn
.
argmax_pushdown_bug
:
if
theano
.
config
.
warn
.
argmax_pushdown_bug
:
...
@@ -1093,28 +1206,33 @@ def local_argmax_pushdown(node):
...
@@ -1093,28 +1206,33 @@ def local_argmax_pushdown(node):
"warning set the Theano flags 'warn.argmax_pushdown_bug' "
"warning set the Theano flags 'warn.argmax_pushdown_bug' "
"to False"
)
"to False"
)
if
node
.
op
==
tensor
.
_max_and_argmax
and
node
.
inputs
[
0
]
.
owner
and
len
(
node
.
outputs
[
0
]
.
clients
)
==
0
:
if
(
node
.
op
==
tensor
.
_max_and_argmax
and
node
.
inputs
[
0
]
.
owner
and
len
(
node
.
outputs
[
0
]
.
clients
)
==
0
):
x_max
,
x_argmax
=
node
.
outputs
x_max
,
x_argmax
=
node
.
outputs
x
,
axis
=
node
.
inputs
x
,
axis
=
node
.
inputs
#TODO: Make a list/set of monotonic ops...
#TODO: Make a list/set of monotonic ops...
if
x
.
owner
and
x
.
owner
.
op
in
(
softmax
,
softplus
,
tensor
.
exp
,
tensor
.
log
,
tensor
.
tanh
,
if
x
.
owner
and
x
.
owner
.
op
in
(
softmax
,
softplus
,
tensor
.
exp
,
sigmoid
):
tensor
.
log
,
tensor
.
tanh
,
sigmoid
):
pre_x
,
=
x
.
owner
.
inputs
pre_x
,
=
x
.
owner
.
inputs
return
tensor
.
_max_and_argmax
(
pre_x
,
axis
)
return
tensor
.
_max_and_argmax
(
pre_x
,
axis
)
if
x
.
owner
and
x
.
owner
.
op
==
softmax_with_bias
:
if
x
.
owner
and
x
.
owner
.
op
==
softmax_with_bias
:
pre_x
,
pre_bias
=
x
.
owner
.
inputs
pre_x
,
pre_bias
=
x
.
owner
.
inputs
return
tensor
.
_max_and_argmax
(
pre_x
+
tensor
.
DimShuffle
(
pre_bias
.
broadcastable
,
return
tensor
.
_max_and_argmax
(
pre_x
+
(
'x'
,
0
))(
pre_bias
),
axis
)
tensor
.
DimShuffle
(
pre_bias
.
broadcastable
,
(
'x'
,
0
))(
pre_bias
),
axis
)
# Utility function used by the two next optimizations
# Utility function used by the two next optimizations
def
_check_rows_is_arange_len_labels
(
rows
,
labels
):
def
_check_rows_is_arange_len_labels
(
rows
,
labels
):
'''Check that 'rows' is the same node as T.arange(labels.shape[0])'''
'''Check that 'rows' is the same node as T.arange(labels.shape[0])'''
if
rows
.
owner
and
isinstance
(
rows
.
owner
.
op
,
tensor
.
ARange
):
if
rows
.
owner
and
isinstance
(
rows
.
owner
.
op
,
tensor
.
ARange
):
start
,
stop
,
step
=
rows
.
owner
.
inputs
start
,
stop
,
step
=
rows
.
owner
.
inputs
if
getattr
(
start
,
'data'
,
None
)
!=
0
:
#
constants will have data
if
getattr
(
start
,
'data'
,
None
)
!=
0
:
#
constants will have data
return
False
return
False
if
getattr
(
step
,
'data'
,
None
)
!=
1
:
# constant step will have data
if
getattr
(
step
,
'data'
,
None
)
!=
1
:
# constant step will have data
return
False
return
False
if
not
stop
.
owner
:
if
not
stop
.
owner
:
return
False
return
False
...
@@ -1131,15 +1249,18 @@ def _check_rows_is_arange_len_labels(rows, labels):
...
@@ -1131,15 +1249,18 @@ def _check_rows_is_arange_len_labels(rows, labels):
shape_of
=
stop
.
owner
.
fgraph
.
shape_feature
.
shape_of
shape_of
=
stop
.
owner
.
fgraph
.
shape_feature
.
shape_of
return
shape_of
[
labels
][
0
]
is
stop
return
shape_of
[
labels
][
0
]
is
stop
def
_is_const
(
z
,
val
,
approx
=
False
):
def
_is_const
(
z
,
val
,
approx
=
False
):
try
:
try
:
maybe
=
opt
.
get_constant_value
(
z
)
maybe
=
opt
.
get_constant_value
(
z
)
except
TypeError
:
except
TypeError
:
return
False
return
False
if
approx
:
if
approx
:
return
numpy
.
allclose
(
maybe
,
val
)
return
numpy
.
allclose
(
maybe
,
val
)
else
:
else
:
return
numpy
.
all
(
maybe
==
val
)
return
numpy
.
all
(
maybe
==
val
)
@opt.register_specialize
@opt.register_specialize
@gof.local_optimizer
([])
@gof.local_optimizer
([])
def
local_advanced_indexing_crossentropy_onehot
(
node
):
def
local_advanced_indexing_crossentropy_onehot
(
node
):
...
@@ -1164,7 +1285,8 @@ def local_advanced_indexing_crossentropy_onehot(node):
...
@@ -1164,7 +1285,8 @@ def local_advanced_indexing_crossentropy_onehot(node):
pass
pass
if
sm
is
not
None
and
sm
.
owner
and
sm
.
owner
.
op
in
(
softmax
,
softmax_with_bias
):
if
sm
is
not
None
and
sm
.
owner
and
sm
.
owner
.
op
in
(
softmax
,
softmax_with_bias
):
sm_w_bias
=
local_softmax_with_bias
.
transform
(
sm
.
owner
)
sm_w_bias
=
local_softmax_with_bias
.
transform
(
sm
.
owner
)
if
sm_w_bias
:
if
sm_w_bias
:
assert
sm_w_bias
[
0
]
.
owner
.
op
==
softmax_with_bias
assert
sm_w_bias
[
0
]
.
owner
.
op
==
softmax_with_bias
...
@@ -1176,7 +1298,10 @@ def local_advanced_indexing_crossentropy_onehot(node):
...
@@ -1176,7 +1298,10 @@ def local_advanced_indexing_crossentropy_onehot(node):
# Check that rows == arange(labels.shape[0])
# Check that rows == arange(labels.shape[0])
if
_check_rows_is_arange_len_labels
(
rows
,
labels
):
if
_check_rows_is_arange_len_labels
(
rows
,
labels
):
if
labels
.
ndim
==
1
and
x_var
.
ndim
==
2
:
if
labels
.
ndim
==
1
and
x_var
.
ndim
==
2
:
return
[
-
crossentropy_softmax_argmax_1hot_with_bias
(
x_var
,
b_var
,
labels
)[
0
]]
return
[
-
crossentropy_softmax_argmax_1hot_with_bias
(
x_var
,
b_var
,
labels
)[
0
]]
@opt.register_specialize
@opt.register_specialize
@gof.local_optimizer
([
softmax_grad
])
@gof.local_optimizer
([
softmax_grad
])
...
@@ -1190,7 +1315,8 @@ def local_advanced_indexing_crossentropy_onehot_grad(node):
...
@@ -1190,7 +1315,8 @@ def local_advanced_indexing_crossentropy_onehot_grad(node):
except
Exception
:
except
Exception
:
return
return
if
(
sm
is
not
None
)
and
sm
.
owner
and
(
sm
.
owner
.
op
in
(
softmax
,
softmax_with_bias
)):
if
(
sm
is
not
None
)
and
sm
.
owner
and
(
sm
.
owner
.
op
in
(
softmax
,
softmax_with_bias
)):
sm_w_bias
=
local_softmax_with_bias
.
transform
(
sm
.
owner
)
sm_w_bias
=
local_softmax_with_bias
.
transform
(
sm
.
owner
)
if
sm_w_bias
:
if
sm_w_bias
:
assert
sm_w_bias
[
0
]
.
owner
.
op
==
softmax_with_bias
assert
sm_w_bias
[
0
]
.
owner
.
op
==
softmax_with_bias
...
@@ -1276,7 +1402,7 @@ def local_advanced_indexing_crossentropy_onehot_grad(node):
...
@@ -1276,7 +1402,7 @@ def local_advanced_indexing_crossentropy_onehot_grad(node):
# set out_grad according to the numerator, it may be divided later
# set out_grad according to the numerator, it may be divided later
# num should be a vector or a scalar
# num should be a vector or a scalar
if
num
.
ndim
==
1
or
numpy
.
all
(
num
.
broadcastable
):
if
num
.
ndim
==
1
or
numpy
.
all
(
num
.
broadcastable
):
out_grad
*=
-
num
out_grad
*=
-
num
else
:
else
:
return
return
...
@@ -1292,15 +1418,17 @@ def local_advanced_indexing_crossentropy_onehot_grad(node):
...
@@ -1292,15 +1418,17 @@ def local_advanced_indexing_crossentropy_onehot_grad(node):
# Try to find the AdvancedSubtensor node mentionned above,
# Try to find the AdvancedSubtensor node mentionned above,
# and the output gradient
# and the output gradient
for
i
,
input
in
enumerate
(
denom
.
owner
.
inputs
):
for
i
,
input
in
enumerate
(
denom
.
owner
.
inputs
):
if
input
.
owner
and
isinstance
(
input
.
owner
.
op
,
tensor
.
AdvancedSubtensor
):
if
input
.
owner
and
isinstance
(
input
.
owner
.
op
,
other_inputs
=
[
in_
for
(
j
,
in_
)
in
enumerate
(
denom
.
owner
.
inputs
)
if
j
!=
i
]
tensor
.
AdvancedSubtensor
):
other_inputs
=
[
in_
for
(
j
,
in_
)
in
enumerate
(
denom
.
owner
.
inputs
)
if
j
!=
i
]
if
len
(
other_inputs
)
==
1
:
if
len
(
other_inputs
)
==
1
:
rest
=
other_inputs
[
0
]
rest
=
other_inputs
[
0
]
else
:
else
:
rest
=
tensor
.
mul
(
*
[
other_inputs
])
rest
=
tensor
.
mul
(
*
[
other_inputs
])
# Check that rest is a vector or a scalar
# Check that rest is a vector or a scalar
if
rest
.
ndim
==
1
or
numpy
.
all
(
rest
.
broadcastable
):
if
rest
.
ndim
==
1
or
numpy
.
all
(
rest
.
broadcastable
):
adv_subtensor
=
input
adv_subtensor
=
input
out_grad
/=
rest
out_grad
/=
rest
break
break
...
@@ -1308,7 +1436,7 @@ def local_advanced_indexing_crossentropy_onehot_grad(node):
...
@@ -1308,7 +1436,7 @@ def local_advanced_indexing_crossentropy_onehot_grad(node):
return
return
# The output gradient needs to be a vector
# The output gradient needs to be a vector
out_grad
=
tensor
.
fill
(
x_var
[:,
0
],
out_grad
)
out_grad
=
tensor
.
fill
(
x_var
[:,
0
],
out_grad
)
if
adv_subtensor
is
not
None
:
if
adv_subtensor
is
not
None
:
try
:
try
:
...
@@ -1316,7 +1444,9 @@ def local_advanced_indexing_crossentropy_onehot_grad(node):
...
@@ -1316,7 +1444,9 @@ def local_advanced_indexing_crossentropy_onehot_grad(node):
except
Exception
:
except
Exception
:
return
return
if
not
(
maybe_sm
is
sm
and
maybe_rows
is
rows
and
maybe_labels
is
labels
):
if
(
not
(
maybe_sm
is
sm
and
maybe_rows
is
rows
and
maybe_labels
is
labels
)):
return
return
#else: OK
#else: OK
else
:
else
:
...
@@ -1394,6 +1524,7 @@ def local_advanced_indexing_crossentropy_onehot_grad(node):
...
@@ -1394,6 +1524,7 @@ def local_advanced_indexing_crossentropy_onehot_grad(node):
else
:
else
:
return
return
@opt.register_specialize
@opt.register_specialize
@gof.local_optimizer
([
softmax_with_bias
])
@gof.local_optimizer
([
softmax_with_bias
])
def
graph_merge_softmax_with_crossentropy_softmax
(
node
):
def
graph_merge_softmax_with_crossentropy_softmax
(
node
):
...
@@ -1421,6 +1552,7 @@ def binary_crossentropy(output, target):
...
@@ -1421,6 +1552,7 @@ def binary_crossentropy(output, target):
"""
"""
return
-
(
target
*
tensor
.
log
(
output
)
+
(
1.0
-
target
)
*
tensor
.
log
(
1.0
-
output
))
return
-
(
target
*
tensor
.
log
(
output
)
+
(
1.0
-
target
)
*
tensor
.
log
(
1.0
-
output
))
def
categorical_crossentropy
(
coding_dist
,
true_dist
):
def
categorical_crossentropy
(
coding_dist
,
true_dist
):
"""
"""
WARNING: THIS FUNCTION IS UNNECESSARILY POLYMORPHIC.
WARNING: THIS FUNCTION IS UNNECESSARILY POLYMORPHIC.
...
@@ -1466,18 +1598,21 @@ def categorical_crossentropy(coding_dist, true_dist):
...
@@ -1466,18 +1598,21 @@ def categorical_crossentropy(coding_dist, true_dist):
from
theano
import
scalar
from
theano
import
scalar
class
Prepend_scalar_constant_to_each_row
(
gof
.
Op
):
class
Prepend_scalar_constant_to_each_row
(
gof
.
Op
):
def
__init__
(
self
,
val
=
0
):
def
__init__
(
self
,
val
=
0
):
if
isinstance
(
val
,
float
):
if
isinstance
(
val
,
float
):
val
=
scalar
.
constant
(
val
)
val
=
scalar
.
constant
(
val
)
self
.
val
=
val
self
.
val
=
val
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
return
(
type
(
self
)
==
type
(
other
))
and
(
self
.
val
==
other
.
val
)
return
(
type
(
self
)
==
type
(
other
))
and
(
self
.
val
==
other
.
val
)
def
__hash__
(
self
):
def
__hash__
(
self
):
return
tensor
.
hashtype
(
self
)
^
hash
(
self
.
val
.
data
)
return
tensor
.
hashtype
(
self
)
^
hash
(
self
.
val
.
data
)
def
__str__
(
self
):
def
__str__
(
self
):
return
'
%
s{
%
s}'
%
(
self
.
__class__
.
__name__
,
self
.
val
)
return
'
%
s{
%
s}'
%
(
self
.
__class__
.
__name__
,
self
.
val
)
def
make_node
(
self
,
mat
):
def
make_node
(
self
,
mat
):
#check type of input
#check type of input
...
@@ -1486,7 +1621,8 @@ class Prepend_scalar_constant_to_each_row(gof.Op):
...
@@ -1486,7 +1621,8 @@ class Prepend_scalar_constant_to_each_row(gof.Op):
x
=
tensor
.
as_tensor_variable
(
mat
)
x
=
tensor
.
as_tensor_variable
(
mat
)
y
=
tensor
.
as_tensor_variable
(
self
.
val
)
y
=
tensor
.
as_tensor_variable
(
self
.
val
)
if
x
.
type
.
dtype
!=
y
.
type
.
dtype
:
if
x
.
type
.
dtype
!=
y
.
type
.
dtype
:
TypeError
(
"the value to prepend don't have the same type as the matrix"
)
TypeError
(
"the value to prepend don't have the same type as the matrix"
)
node
=
Apply
(
op
=
self
,
inputs
=
[
mat
],
outputs
=
[
tensor
.
matrix
()])
node
=
Apply
(
op
=
self
,
inputs
=
[
mat
],
outputs
=
[
tensor
.
matrix
()])
return
node
return
node
...
@@ -1494,31 +1630,34 @@ class Prepend_scalar_constant_to_each_row(gof.Op):
...
@@ -1494,31 +1630,34 @@ class Prepend_scalar_constant_to_each_row(gof.Op):
def
perform
(
self
,
node
,
inp
,
out
):
def
perform
(
self
,
node
,
inp
,
out
):
mat
,
=
inp
mat
,
=
inp
output
,
=
out
output
,
=
out
new_shape
=
(
mat
.
shape
[
0
],
mat
.
shape
[
1
]
+
1
)
new_shape
=
(
mat
.
shape
[
0
],
mat
.
shape
[
1
]
+
1
)
if
output
[
0
]
==
None
:
if
output
[
0
]
==
None
:
output
[
0
]
=
numpy
.
empty
(
new_shape
,
dtype
=
mat
.
dtype
)
output
[
0
]
=
numpy
.
empty
(
new_shape
,
dtype
=
mat
.
dtype
)
out
=
output
[
0
]
out
=
output
[
0
]
else
:
else
:
if
output
[
0
]
.
shape
!=
new_shape
:
if
output
[
0
]
.
shape
!=
new_shape
:
try
:
try
:
output
[
0
]
.
resize
(
new_shape
)
output
[
0
]
.
resize
(
new_shape
)
except
Exception
:
except
Exception
:
output
[
0
]
=
numpy
.
empty
(
new_shape
,
dtype
=
mat
.
dtype
)
output
[
0
]
=
numpy
.
empty
(
new_shape
,
dtype
=
mat
.
dtype
)
out
=
output
[
0
]
out
=
output
[
0
]
out
[:,
0
]
.
fill
(
self
.
val
.
data
)
out
[:,
0
]
.
fill
(
self
.
val
.
data
)
out
[:,
1
:]
=
mat
out
[:,
1
:]
=
mat
def
grad
(
self
,
inp
,
grads
):
def
grad
(
self
,
inp
,
grads
):
mat
,
=
inp
mat
,
=
inp
goutput
,
=
grads
goutput
,
=
grads
return
goutput
[:,
1
:]
return
goutput
[:,
1
:]
class
Prepend_scalar_to_each_row
(
gof
.
Op
):
class
Prepend_scalar_to_each_row
(
gof
.
Op
):
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
return
(
type
(
self
)
==
type
(
other
))
return
(
type
(
self
)
==
type
(
other
))
def
__hash__
(
self
):
def
__hash__
(
self
):
return
tensor
.
hashtype
(
self
)
return
tensor
.
hashtype
(
self
)
def
__str__
(
self
):
def
__str__
(
self
):
return
self
.
__class__
.
__name__
return
self
.
__class__
.
__name__
...
@@ -1526,37 +1665,39 @@ class Prepend_scalar_to_each_row(gof.Op):
...
@@ -1526,37 +1665,39 @@ class Prepend_scalar_to_each_row(gof.Op):
def make_node(self, val, mat):
    """Build the Apply node prepending scalar `val` as a new first column of `mat`.

    NOTE(review): the original ``def`` line was elided in this diff view;
    the signature ``(self, val, mat)`` is reconstructed from perform's
    ``val, mat = inp`` and the ``inputs=[val, mat]`` below -- confirm.

    Raises TypeError if `mat` is not a matrix Variable or if the dtypes
    of `val` and `mat` disagree.
    """
    # Check/coerce the input types.
    if isinstance(val, float):
        val = scalar.constant(val)
    if (not isinstance(mat, gof.Variable)
            or not mat.type == tensor.matrix().type):
        raise TypeError("Expected a matrix as input")
    x = tensor.as_tensor_variable(mat)
    y = tensor.as_tensor_variable(val)
    if x.type.dtype != y.type.dtype:
        # BUG FIX: the exception was constructed but never raised, so a
        # dtype mismatch silently slipped through this check.
        raise TypeError(
            "the value to prepend don't have the same type as the matrix")
    node = Apply(op=self, inputs=[val, mat], outputs=[tensor.matrix()])
    return node
def perform(self, node, inp, out):
    """Prepend a first column filled with the scalar `val` to `mat`.

    Parameters (Theano Op.perform convention):
      inp -- ``(val, mat)``: a scalar and a 2-d ndarray.
      out -- ``(output,)`` where `output` is a one-element list used as
             the (possibly reusable) output storage cell; on return it
             holds an array of shape ``(mat.shape[0], mat.shape[1] + 1)``.
    """
    val, mat = inp
    output, = out
    new_shape = (mat.shape[0], mat.shape[1] + 1)
    # BUG FIX: was `output[0] == None`.  On the reuse path output[0] is
    # an ndarray, and `ndarray == None` broadcasts elementwise (ambiguous
    # as a truth value on modern NumPy).  Identity test is the correct check.
    if output[0] is None:
        output[0] = numpy.empty(new_shape, dtype=mat.dtype)
        out = output[0]
    else:
        if output[0].shape != new_shape:
            try:
                # Try to reshape the existing buffer in place first.
                output[0].resize(new_shape)
            except Exception:
                # ndarray.resize refuses when the buffer is referenced
                # elsewhere; fall back to a fresh allocation.
                output[0] = numpy.empty(new_shape, dtype=mat.dtype)
        out = output[0]
    out[:, 0].fill(val)
    out[:, 1:] = mat
def grad(self, inp, grads):
    """Gradients w.r.t. (val, mat): first output column, and the rest."""
    val, mat = inp
    (g_out,) = grads
    # The prepended column carries val's contribution; the remaining
    # columns map one-to-one onto mat.
    return g_out[:, 0], g_out[:, 1:]
# Ready-made Op instances.  Sharing single instances is safe here:
# these Ops hold no mutable state (the constant variant presumably
# stores only its fixed value -- verify against its __init__) and
# compare equal by type, so reuse across graphs is harmless.
prepend_scalar_to_each_row = Prepend_scalar_to_each_row()
prepend_0_to_each_row = Prepend_scalar_constant_to_each_row(0.)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论