Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
13db39ca
提交
13db39ca
authored
4月 15, 2008
作者:
bergstrj@iro.umontreal.ca
浏览文件
操作
浏览文件
下载
差异文件
merged, changed gemm to use fortran blas instead of cblas
上级
476c9f66
c64b127c
显示空白字符变更
内嵌
并排
正在显示
13 个修改的文件
包含
991 行增加
和
102 行删除
+991
-102
_test_elemwise.py
_test_elemwise.py
+10
-0
_test_opt.py
_test_opt.py
+182
-0
_test_scalar.py
_test_scalar.py
+44
-0
_test_scalar_opt.py
_test_scalar_opt.py
+32
-0
_test_tensor.py
_test_tensor.py
+49
-5
blas.py
blas.py
+219
-4
elemwise.py
elemwise.py
+0
-3
cc.py
gof/cc.py
+1
-1
result.py
gof/result.py
+18
-10
opt.py
opt.py
+231
-22
scalar.py
scalar.py
+126
-3
scalar_opt.py
scalar_opt.py
+14
-0
tensor.py
tensor.py
+65
-54
没有找到文件。
_test_elemwise.py
浏览文件 @
13db39ca
...
@@ -120,6 +120,16 @@ class _test_Broadcast(unittest.TestCase):
...
@@ -120,6 +120,16 @@ class _test_Broadcast(unittest.TestCase):
f
(
xv
,
yv
)
f
(
xv
,
yv
)
assert
(
xv
==
yv
)
.
all
()
assert
(
xv
==
yv
)
.
all
()
def
test_weird_strides
(
self
):
x
=
modes
.
build
(
Tensor
(
'float64'
,
[
0
,
0
,
0
,
0
,
0
],
name
=
'x'
))
y
=
modes
.
build
(
Tensor
(
'float64'
,
[
0
,
0
,
0
,
0
,
0
],
name
=
'y'
))
e
=
Broadcast
(
Add
,
(
x
,
y
))
.
out
f
=
gof
.
CLinker
(
env
([
x
,
y
],
[
e
]))
.
make_function
(
inplace
=
False
)
xv
=
numpy
.
random
.
rand
(
2
,
2
,
2
,
2
,
2
)
yv
=
numpy
.
random
.
rand
(
2
,
2
,
2
,
2
,
2
)
.
transpose
(
4
,
0
,
3
,
1
,
2
)
zv
=
xv
+
yv
assert
(
f
(
xv
,
yv
)
==
zv
)
.
all
()
class
_test_CAReduce
(
unittest
.
TestCase
):
class
_test_CAReduce
(
unittest
.
TestCase
):
...
...
_test_opt.py
0 → 100644
浏览文件 @
13db39ca
import
unittest
import
gof
from
opt
import
*
import
tensor
from
tensor
import
Tensor
from
gof
import
Env
from
elemwise
import
DimShuffle
import
numpy
import
scalar_opt
def
inputs
(
xbc
=
(
0
,
0
),
ybc
=
(
0
,
0
),
zbc
=
(
0
,
0
)):
x
=
Tensor
(
broadcastable
=
xbc
,
dtype
=
'float64'
,
name
=
'x'
)
y
=
Tensor
(
broadcastable
=
ybc
,
dtype
=
'float64'
,
name
=
'y'
)
z
=
Tensor
(
broadcastable
=
zbc
,
dtype
=
'float64'
,
name
=
'z'
)
return
x
,
y
,
z
ds
=
gof
.
op
.
constructor
(
DimShuffle
)
class
_test_inplace_opt
(
unittest
.
TestCase
):
def
test_straightforward
(
self
):
x
,
y
,
z
=
inputs
()
e
=
x
+
y
+
z
g
=
Env
([
x
,
y
],
[
e
])
assert
str
(
g
)
==
"[Broadcast{Add}(Broadcast{Add}(x, y), z)]"
inplace_optimizer
.
optimize
(
g
)
assert
str
(
g
)
==
"[Broadcast{Add}{0: 0}(Broadcast{Add}{0: 0}(x, y), z)]"
def
test_multiple_uses
(
self
):
x
,
y
,
z
=
inputs
()
e0
=
x
+
y
e1
=
x
*
y
g
=
Env
([
x
,
y
],
[
e0
,
e1
])
assert
str
(
g
)
==
"[Broadcast{Add}(x, y), Broadcast{Mul}(x, y)]"
inplace_optimizer
.
optimize
(
g
)
assert
str
(
g
)
==
"[Broadcast{Add}{0: 0}(x, y), Broadcast{Mul}(x, y)]"
\
or
str
(
g
)
==
"[Broadcast{Add}(x, y), Broadcast{Mul}{0: 0}(x, y)]"
def
test_user_inplace
(
self
):
x
,
y
,
z
=
inputs
()
e0
=
x
+
y
e1
=
tensor
.
mul_inplace
(
x
,
y
)
g
=
Env
([
x
,
y
],
[
e0
,
e1
])
assert
str
(
g
)
==
"[Broadcast{Add}(x, y), Broadcast{Mul}{0: 0}(x, y)]"
inplace_optimizer
.
optimize
(
g
)
assert
str
(
g
)
==
"[Broadcast{Add}(x, y), Broadcast{Mul}{0: 0}(x, y)]"
class
_test_dimshuffle_lift
(
unittest
.
TestCase
):
def
test_double_transpose
(
self
):
x
,
y
,
z
=
inputs
()
e
=
ds
(
ds
(
x
,
(
1
,
0
)),
(
1
,
0
))
g
=
Env
([
x
],
[
e
])
assert
str
(
g
)
==
"[DimShuffle{10}(DimShuffle{10}(x))]"
lift_dimshuffle
.
optimize
(
g
)
assert
str
(
g
)
==
"[x]"
def
test_merge2
(
self
):
x
,
y
,
z
=
inputs
()
e
=
ds
(
ds
(
x
,
(
1
,
'x'
,
0
)),
(
2
,
0
,
'x'
,
1
))
g
=
Env
([
x
],
[
e
])
self
.
failUnless
(
str
(
g
)
==
"[DimShuffle{20x1}(DimShuffle{1x0}(x))]"
,
str
(
g
))
lift_dimshuffle
.
optimize
(
g
)
self
.
failUnless
(
str
(
g
)
==
"[DimShuffle{01xx}(x)]"
,
str
(
g
))
def
test_elim3
(
self
):
x
,
y
,
z
=
inputs
()
e
=
ds
(
ds
(
ds
(
x
,
(
0
,
'x'
,
1
)),
(
2
,
0
,
'x'
,
1
)),
(
1
,
0
))
g
=
Env
([
x
],
[
e
])
self
.
failUnless
(
str
(
g
)
==
"[DimShuffle{10}(DimShuffle{20x1}(DimShuffle{0x1}(x)))]"
,
str
(
g
))
lift_dimshuffle
.
optimize
(
g
)
self
.
failUnless
(
str
(
g
)
==
"[x]"
,
str
(
g
))
def
test_lift
(
self
):
x
,
y
,
z
=
inputs
([
0
]
*
1
,
[
0
]
*
2
,
[
0
]
*
3
)
e
=
x
+
y
+
z
g
=
Env
([
x
,
y
,
z
],
[
e
])
self
.
failUnless
(
str
(
g
)
==
"[Broadcast{Add}(DimShuffle{x01}(Broadcast{Add}(DimShuffle{x0}(x), y)), z)]"
,
str
(
g
))
lift_dimshuffle
.
optimize
(
g
)
self
.
failUnless
(
str
(
g
)
==
"[Broadcast{Add}(Broadcast{Add}(DimShuffle{xx0}(x), DimShuffle{x01}(y)), z)]"
,
str
(
g
))
class
_test_cliques
(
unittest
.
TestCase
):
def
test_straightforward
(
self
):
x
,
y
,
z
=
inputs
()
m
=
y
*
z
d
=
tensor
.
dot
(
x
,
m
)
d
.
name
=
'd'
e
=
x
+
y
+
d
g
=
Env
([
x
,
y
,
z
],
[
e
])
cliques
=
find_cliques
(
g
)
assert
len
(
cliques
)
==
2
(
i1
,
o1
),
(
i2
,
o2
)
=
cliques
assert
str
(
Env
(
i1
,
o1
))
==
"[Broadcast{Add}(Broadcast{Add}(x, y), d)]"
assert
str
(
Env
(
i2
,
o2
))
==
"[Broadcast{Mul}(y, z)]"
# print g
# for i, o in find_cliques(g):
# print "-->", Env(i, [o])
def
test_broadcasting
(
self
):
x
,
y
,
z
=
inputs
([
0
]
*
1
,
[
0
]
*
2
,
[
0
]
*
3
)
e
=
x
+
y
+
z
g
=
Env
([
x
,
y
,
z
],
[
e
])
lift_dimshuffle
.
optimize
(
g
)
assert
len
(
find_cliques
(
g
,
through_broadcast
=
True
))
==
1
assert
len
(
find_cliques
(
g
,
through_broadcast
=
False
))
==
2
# print g
# for i, o in find_cliques(g, True):
# print "-->", Env(i, [o])
# class _test_clique_opt(unittest.TestCase):
# def test_straightforward(self):
# x, y, z = inputs()
# e = x ** 2.0 #x * x
# g = Env([x], [e])
# gof.ConstantFinder().optimize(g)
# opt = CliqueOptimizer(through_broadcast = False,
# scalar_optimizer = scalar_opt.opt2,
# make_composite = False)
# print g
# opt.optimize(g)
# print g
# def test_inplace(self):
# x, y, z = inputs()
# #e = tensor.add_inplace(x, y + z)
# e = x + tensor.add_inplace(y, z)
# g = Env([x, y, z], [e])
# opt = CliqueOptimizer(through_broadcast = False,
# scalar_optimizer = None,
# make_composite = True)
# print g
# opt.optimize(g)
# print g
# # print g.outputs[0].owner.c_code(['x', 'y', 'z'], ['e'], dict(fail = "FAIL;", id = 0))
# print gof.OpWiseCLinker(g).make_function()(numpy.ones((5, 5)), numpy.ones((5, 5)), numpy.ones((5, 5)))
# def test_straightforward(self):
# x, y, z = inputs()
# e = x + y + z
# g = Env([x, y, z], [e])
# opt = CliqueOptimizer(through_broadcast = False,
# scalar_optimizer = None,
# make_composite = True)
# print g
# opt.optimize(g)
# print g
# # print g.outputs[0].owner.c_code(['x', 'y', 'z'], ['e'], dict(fail = "FAIL;", id = 0))
# print gof.OpWiseCLinker(g).make_function()(numpy.ones((5, 5)), numpy.ones((5, 5)), numpy.ones((5, 5)))
# def test_straightforward2(self):
# x, y, z = inputs()
# m = y * z
# d = tensor.dot(x, m)
# d.name = 'd'
# e = x + y + d
# g = Env([x, y, z], [e])
# opt = CliqueOptimizer(through_broadcast = False,
# scalar_optimizer = None,
# make_composite = True)
# print g
# opt.optimize(g)
# print g
# # print g.outputs[0].owner.c_code(['x', 'y', 'z'], ['e'], dict(fail = "FAIL;", id = 0))
# print gof.OpWiseCLinker(g).make_function()(numpy.ones((5, 5)), numpy.ones((5, 5)), numpy.ones((5, 5)))
if
__name__
==
'__main__'
:
unittest
.
main
()
_test_scalar.py
浏览文件 @
13db39ca
...
@@ -27,6 +27,50 @@ class _test_ScalarOps(unittest.TestCase):
...
@@ -27,6 +27,50 @@ class _test_ScalarOps(unittest.TestCase):
assert
fn
(
1.0
,
2.0
)
==
1.5
assert
fn
(
1.0
,
2.0
)
==
1.5
class
_test_composite
(
unittest
.
TestCase
):
def
test_straightforward
(
self
):
x
,
y
,
z
=
inputs
()
e
=
mul
(
add
(
x
,
y
),
div
(
x
,
y
))
C
=
composite
([
x
,
y
],
[
e
])
c
=
C
(
x
,
y
)
# print c.c_code(['x', 'y'], ['z'], dict(id = 0))
c
.
perform
()
assert
c
.
outputs
[
0
]
.
data
==
1.5
g
=
env
([
x
,
y
],
[
c
.
out
])
fn
=
gof
.
DualLinker
(
g
)
.
make_function
()
assert
fn
(
1.0
,
2.0
)
==
1.5
def
test_with_constants
(
self
):
x
,
y
,
z
=
inputs
()
e
=
mul
(
add
(
70.0
,
y
),
div
(
x
,
y
))
C
=
composite
([
x
,
y
],
[
e
])
c
=
C
(
x
,
y
)
assert
"70.0"
in
c
.
c_code
([
'x'
,
'y'
],
[
'z'
],
dict
(
id
=
0
))
# print c.c_code(['x', 'y'], ['z'], dict(id = 0))
c
.
perform
()
assert
c
.
outputs
[
0
]
.
data
==
36.0
g
=
env
([
x
,
y
],
[
c
.
out
])
fn
=
gof
.
DualLinker
(
g
)
.
make_function
()
assert
fn
(
1.0
,
2.0
)
==
36.0
def
test_many_outputs
(
self
):
x
,
y
,
z
=
inputs
()
e0
=
x
+
y
+
z
e1
=
x
+
y
*
z
e2
=
x
/
y
C
=
composite
([
x
,
y
,
z
],
[
e0
,
e1
,
e2
])
c
=
C
(
x
,
y
,
z
)
# print c.c_code(['x', 'y', 'z'], ['out0', 'out1', 'out2'], dict(id = 0))
c
.
perform
()
assert
c
.
outputs
[
0
]
.
data
==
6.0
assert
c
.
outputs
[
1
]
.
data
==
7.0
assert
c
.
outputs
[
2
]
.
data
==
0.5
g
=
env
([
x
,
y
],
c
.
outputs
)
fn
=
gof
.
DualLinker
(
g
)
.
make_function
()
assert
fn
(
1.0
,
2.0
)
==
[
6.0
,
7.0
,
0.5
]
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
...
...
_test_scalar_opt.py
0 → 100644
浏览文件 @
13db39ca
import
unittest
from
gof
import
Result
,
Op
,
Env
,
modes
import
gof
from
scalar
import
*
from
scalar_opt
import
*
def
inputs
():
x
=
Scalar
(
'float64'
,
name
=
'x'
)
y
=
Scalar
(
'float64'
,
name
=
'y'
)
z
=
Scalar
(
'float64'
,
name
=
'z'
)
return
x
,
y
,
z
class
_test_opts
(
unittest
.
TestCase
):
def
test_pow_to_sqr
(
self
):
x
,
y
,
z
=
inputs
()
e
=
x
**
2.0
g
=
Env
([
x
],
[
e
])
assert
str
(
g
)
==
"[Pow(x, 2.0)]"
gof
.
ConstantFinder
()
.
optimize
(
g
)
opt2
.
optimize
(
g
)
assert
str
(
g
)
==
"[Sqr(x)]"
if
__name__
==
'__main__'
:
unittest
.
main
()
_test_tensor.py
浏览文件 @
13db39ca
...
@@ -990,21 +990,21 @@ class t_gemm(unittest.TestCase):
...
@@ -990,21 +990,21 @@ class t_gemm(unittest.TestCase):
def
cmp
(
self
,
z
,
a
,
x
,
y
,
b
):
def
cmp
(
self
,
z
,
a
,
x
,
y
,
b
):
def
cmp_linker
(
z
,
a
,
x
,
y
,
b
,
l
):
def
cmp_linker
(
z
,
a
,
x
,
y
,
b
,
l
):
z
,
a
,
x
,
y
,
b
=
[
numpy
.
asarray
(
p
)
for
p
in
z
,
a
,
x
,
y
,
b
]
z
,
a
,
x
,
y
,
b
=
[
numpy
.
asarray
(
p
)
for
p
in
z
,
a
,
x
,
y
,
b
]
cz
=
z
.
copy
()
z_orig
=
z
.
copy
()
tz
,
ta
,
tx
,
ty
,
tb
=
[
astensor
(
p
)
for
p
in
z
,
a
,
x
,
y
,
b
]
tz
,
ta
,
tx
,
ty
,
tb
=
[
astensor
(
p
)
for
p
in
z
,
a
,
x
,
y
,
b
]
f
=
Function
([
tz
,
ta
,
tx
,
ty
,
tb
],
[
gemm
(
tz
,
ta
,
tx
,
ty
,
tb
)],
linker_cls
=
l
)
f
=
Function
([
tz
,
ta
,
tx
,
ty
,
tb
],
[
gemm
(
tz
,
ta
,
tx
,
ty
,
tb
)],
linker_cls
=
l
)
new_z
=
f
(
z
,
a
,
x
,
y
,
b
)
new_z
=
f
(
z
,
a
,
x
,
y
,
b
)
_z
=
self
.
_gemm
(
cz
,
a
,
x
,
y
,
b
)
z_after
=
self
.
_gemm
(
z_orig
,
a
,
x
,
y
,
b
)
self
.
failUnless
(
z
is
new_z
)
self
.
failUnless
(
z
is
new_z
)
#print cz, _z, z, type(cz), type(_z), type(z)
#print z_orig, z_after, z, type(z_orig), type(z_after), type(z)
#_approx_eq.debug = 1
#_approx_eq.debug = 1
self
.
failUnless
(
_approx_eq
(
_z
,
z
))
self
.
failUnless
(
_approx_eq
(
z_after
,
z
))
if
a
==
0.0
and
b
==
1.0
:
if
a
==
0.0
and
b
==
1.0
:
return
return
else
:
else
:
self
.
failIf
(
numpy
.
all
(
cz
==
z
))
self
.
failIf
(
numpy
.
all
(
z_orig
==
z
))
cmp_linker
(
copy
(
z
),
a
,
x
,
y
,
b
,
gof
.
cc
.
OpWiseCLinker
)
cmp_linker
(
copy
(
z
),
a
,
x
,
y
,
b
,
gof
.
cc
.
OpWiseCLinker
)
#cmp_linker(copy(z), a, x, y, b, gof.cc.CLinker)
#cmp_linker(copy(z), a, x, y, b, gof.cc.CLinker)
...
@@ -1101,5 +1101,49 @@ class t_gemm(unittest.TestCase):
...
@@ -1101,5 +1101,49 @@ class t_gemm(unittest.TestCase):
eval_outputs
([
gemm
(
Z
,
1.0
,
A
,
A
,
1.0
)])
eval_outputs
([
gemm
(
Z
,
1.0
,
A
,
A
,
1.0
)])
eval_outputs
([
gemm
(
Z
,
1.0
,
A
,
A
.
T
,
1.0
)])
eval_outputs
([
gemm
(
Z
,
1.0
,
A
,
A
.
T
,
1.0
)])
def
test_transposes
(
self
):
# three square matrices which are not contiguous
A
=
self
.
rand
(
4
,
5
)[:,:
4
]
B
=
self
.
rand
(
4
,
5
)[:,:
4
]
C
=
self
.
rand
(
4
,
5
)[:,:
4
]
def
t
(
z
,
x
,
y
,
a
=
1.0
,
b
=
0.0
,
l
=
gof
.
cc
.
OpWiseCLinker
):
z
,
a
,
x
,
y
,
b
=
[
numpy
.
asarray
(
p
)
for
p
in
z
,
a
,
x
,
y
,
b
]
z_orig
=
z
.
copy
()
z_after
=
self
.
_gemm
(
z
,
a
,
x
,
y
,
b
)
tz
,
ta
,
tx
,
ty
,
tb
=
[
astensor
(
p
)
for
p
in
z
,
a
,
x
,
y
,
b
]
f
=
Function
([
tz
,
ta
,
tx
,
ty
,
tb
],
[
gemm
(
tz
,
ta
,
tx
,
ty
,
tb
)],
linker_cls
=
l
)
f
(
z
,
a
,
x
,
y
,
b
)
self
.
failUnless
(
_approx_eq
(
z_after
,
z
),
(
z_orig
,
z_after
,
z
))
f
(
z
.
T
,
a
,
y
.
T
,
x
.
T
,
b
)
self
.
failUnless
(
_approx_eq
(
z_after
,
z
))
t
(
C
,
A
,
B
)
t
(
C
.
T
,
A
,
B
)
t
(
C
,
A
.
T
,
B
)
t
(
C
,
A
,
B
.
T
)
t
(
C
.
T
,
A
.
T
,
B
)
t
(
C
,
A
.
T
,
B
.
T
)
t
(
C
.
T
,
A
,
B
.
T
)
t
(
C
.
T
,
A
.
T
,
B
.
T
)
t
(
C
,
A
[:,:
2
],
B
[:
2
,
:])
t
(
C
.
T
,
A
[:,:
2
],
B
[:
2
,
:])
t
(
C
,
A
[:
2
,:]
.
T
,
B
[:
2
,
:])
t
(
C
.
T
,
A
[:
2
,:]
.
T
,
B
[:
2
,
:])
t
(
C
,
A
[:
2
,:]
.
T
,
B
[:,
:
2
]
.
T
)
t
(
C
.
T
,
A
[:
2
,:]
.
T
,
B
[:,
:
2
]
.
T
)
try
:
t
(
C
.
T
,
A
[:
2
,:],
B
[:,
:
2
]
.
T
)
except
ValueError
,
e
:
if
e
[
0
]
.
find
(
'aligned'
)
>=
0
:
return
self
.
fail
()
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
blas.py
浏览文件 @
13db39ca
...
@@ -11,6 +11,7 @@ fine-grained motifs of iadd, isub, scale, and dot.
...
@@ -11,6 +11,7 @@ fine-grained motifs of iadd, isub, scale, and dot.
"""
"""
def
cblas_header_text
():
def
cblas_header_text
():
"""C header for the cblas interface"""
return
"""
return
"""
//#include <stddef.h>
//#include <stddef.h>
...
@@ -589,6 +590,210 @@ def cblas_header_text():
...
@@ -589,6 +590,210 @@ def cblas_header_text():
__END_DECLS
__END_DECLS
"""
"""
def
blas_proto
():
"""C header for the fortran blas interface"""
return
"""
extern "C"
{
void xerbla_(char*, void *);
/***********/
/* Level 1 */
/***********/
/* Single Precision */
void srot_(const int*, float *, const int*, float *, const int*, const float *, const float *);
void srotg_(float *,float *,float *,float *);
void srotm_( const int*, float *, const int*, float *, const int*, const float *);
void srotmg_(float *,float *,float *,const float *, float *);
void sswap_( const int*, float *, const int*, float *, const int*);
void scopy_( const int*, const float *, const int*, float *, const int*);
void saxpy_( const int*, const float *, const float *, const int*, float *, const int*);
void sdot_sub_(const int*, const float *, const int*, const float *, const int*, float *);
void sdsdot_sub_( const int*, const float *, const float *, const int*, const float *, const int*, float *);
void sscal_( const int*, const float *, float *, const int*);
void snrm2_sub_( const int*, const float *, const int*, float *);
void sasum_sub_( const int*, const float *, const int*, float *);
void isamax_sub_( const int*, const float * , const int*, const int*);
/* Double Precision */
void drot_(const int*, double *, const int*, double *, const int*, const double *, const double *);
void drotg_(double *,double *,double *,double *);
void drotm_( const int*, double *, const int*, double *, const int*, const double *);
void drotmg_(double *,double *,double *,const double *, double *);
void dswap_( const int*, double *, const int*, double *, const int*);
void dcopy_( const int*, const double *, const int*, double *, const int*);
void daxpy_( const int*, const double *, const double *, const int*, double *, const int*);
void dswap_( const int*, double *, const int*, double *, const int*);
void dsdot_sub_(const int*, const float *, const int*, const float *, const int*, double *);
void ddot_sub_( const int*, const double *, const int*, const double *, const int*, double *);
void dscal_( const int*, const double *, double *, const int*);
void dnrm2_sub_( const int*, const double *, const int*, double *);
void dasum_sub_( const int*, const double *, const int*, double *);
void idamax_sub_( const int*, const double * , const int*, const int*);
/* Single Complex Precision */
void cswap_( const int*, void *, const int*, void *, const int*);
void ccopy_( const int*, const void *, const int*, void *, const int*);
void caxpy_( const int*, const void *, const void *, const int*, void *, const int*);
void cswap_( const int*, void *, const int*, void *, const int*);
void cdotc_sub_( const int*, const void *, const int*, const void *, const int*, void *);
void cdotu_sub_( const int*, const void *, const int*, const void *, const int*, void *);
void cscal_( const int*, const void *, void *, const int*);
void icamax_sub_( const int*, const void *, const int*, const int*);
void csscal_( const int*, const float *, void *, const int*);
void scnrm2_sub_( const int*, const void *, const int*, float *);
void scasum_sub_( const int*, const void *, const int*, float *);
/* Double Complex Precision */
void zswap_( const int*, void *, const int*, void *, const int*);
void zcopy_( const int*, const void *, const int*, void *, const int*);
void zaxpy_( const int*, const void *, const void *, const int*, void *, const int*);
void zswap_( const int*, void *, const int*, void *, const int*);
void zdotc_sub_( const int*, const void *, const int*, const void *, const int*, void *);
void zdotu_sub_( const int*, const void *, const int*, const void *, const int*, void *);
void zdscal_( const int*, const double *, void *, const int*);
void zscal_( const int*, const void *, void *, const int*);
void dznrm2_sub_( const int*, const void *, const int*, double *);
void dzasum_sub_( const int*, const void *, const int*, double *);
void izamax_sub_( const int*, const void *, const int*, const int*);
/***********/
/* Level 2 */
/***********/
/* Single Precision */
void sgemv_(char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void sgbmv_(char*, const int*, const int*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void ssymv_(char*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void ssbmv_(char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void sspmv_(char*, const int*, const float *, const float *, const float *, const int*, const float *, float *, const int*);
void strmv_( char*, char*, char*, const int*, const float *, const int*, float *, const int*);
void stbmv_( char*, char*, char*, const int*, const int*, const float *, const int*, float *, const int*);
void strsv_( char*, char*, char*, const int*, const float *, const int*, float *, const int*);
void stbsv_( char*, char*, char*, const int*, const int*, const float *, const int*, float *, const int*);
void stpmv_( char*, char*, char*, const int*, const float *, float *, const int*);
void stpsv_( char*, char*, char*, const int*, const float *, float *, const int*);
void sger_( const int*, const int*, const float *, const float *, const int*, const float *, const int*, float *, const int*);
void ssyr_(char*, const int*, const float *, const float *, const int*, float *, const int*);
void sspr_(char*, const int*, const float *, const float *, const int*, float *);
void sspr2_(char*, const int*, const float *, const float *, const int*, const float *, const int*, float *);
void ssyr2_(char*, const int*, const float *, const float *, const int*, const float *, const int*, float *, const int*);
/* Double Precision */
void dgemv_(char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dgbmv_(char*, const int*, const int*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dsymv_(char*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dsbmv_(char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dspmv_(char*, const int*, const double *, const double *, const double *, const int*, const double *, double *, const int*);
void dtrmv_( char*, char*, char*, const int*, const double *, const int*, double *, const int*);
void dtbmv_( char*, char*, char*, const int*, const int*, const double *, const int*, double *, const int*);
void dtrsv_( char*, char*, char*, const int*, const double *, const int*, double *, const int*);
void dtbsv_( char*, char*, char*, const int*, const int*, const double *, const int*, double *, const int*);
void dtpmv_( char*, char*, char*, const int*, const double *, double *, const int*);
void dtpsv_( char*, char*, char*, const int*, const double *, double *, const int*);
void dger_( const int*, const int*, const double *, const double *, const int*, const double *, const int*, double *, const int*);
void dsyr_(char*, const int*, const double *, const double *, const int*, double *, const int*);
void dspr_(char*, const int*, const double *, const double *, const int*, double *);
void dspr2_(char*, const int*, const double *, const double *, const int*, const double *, const int*, double *);
void dsyr2_(char*, const int*, const double *, const double *, const int*, const double *, const int*, double *, const int*);
/* Single Complex Precision */
void cgemv_(char*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void cgbmv_(char*, const int*, const int*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void chemv_(char*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void chbmv_(char*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void chpmv_(char*, const int*, const void *, const void *, const void *, const int*, const void *, void *, const int*);
void ctrmv_( char*, char*, char*, const int*, const void *, const int*, void *, const int*);
void ctbmv_( char*, char*, char*, const int*, const int*, const void *, const int*, void *, const int*);
void ctpmv_( char*, char*, char*, const int*, const void *, void *, const int*);
void ctrsv_( char*, char*, char*, const int*, const void *, const int*, void *, const int*);
void ctbsv_( char*, char*, char*, const int*, const int*, const void *, const int*, void *, const int*);
void ctpsv_( char*, char*, char*, const int*, const void *, void *,const int*);
void cgerc_( const int*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*);
void cgeru_( const int*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*);
void cher_(char*, const int*, const float *, const void *, const int*, void *, const int*);
void cher2_(char*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*);
void chpr_(char*, const int*, const float *, const void *, const int*, void *);
void chpr2_(char*, const int*, const float *, const void *, const int*, const void *, const int*, void *);
/* Double Complex Precision */
void zgemv_(char*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void zgbmv_(char*, const int*, const int*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void zhemv_(char*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void zhbmv_(char*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void zhpmv_(char*, const int*, const void *, const void *, const void *, const int*, const void *, void *, const int*);
void ztrmv_( char*, char*, char*, const int*, const void *, const int*, void *, const int*);
void ztbmv_( char*, char*, char*, const int*, const int*, const void *, const int*, void *, const int*);
void ztpmv_( char*, char*, char*, const int*, const void *, void *, const int*);
void ztrsv_( char*, char*, char*, const int*, const void *, const int*, void *, const int*);
void ztbsv_( char*, char*, char*, const int*, const int*, const void *, const int*, void *, const int*);
void ztpsv_( char*, char*, char*, const int*, const void *, void *,const int*);
void zgerc_( const int*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*);
void zgeru_( const int*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*);
void zher_(char*, const int*, const double *, const void *, const int*, void *, const int*);
void zher2_(char*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*);
void zhpr_(char*, const int*, const double *, const void *, const int*, void *);
void zhpr2_(char*, const int*, const double *, const void *, const int*, const void *, const int*, void *);
/***********/
/* Level 3 */
/***********/
/* Single Precision */
void sgemm_(char*, char*, const int*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void ssymm_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void ssyrk_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, float *, const int*);
void ssyr2k_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void strmm_(char*, char*, char*, char*, const int*, const int*, const float *, const float *, const int*, float *, const int*);
void strsm_(char*, char*, char*, char*, const int*, const int*, const float *, const float *, const int*, float *, const int*);
/* Double Precision */
void dgemm_(char*, char*, const int*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dsymm_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dsyrk_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, double *, const int*);
void dsyr2k_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dtrmm_(char*, char*, char*, char*, const int*, const int*, const double *, const double *, const int*, double *, const int*);
void dtrsm_(char*, char*, char*, char*, const int*, const int*, const double *, const double *, const int*, double *, const int*);
/* Single Complex Precision */
void cgemm_(char*, char*, const int*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void csymm_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void chemm_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void csyrk_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, float *, const int*);
void cherk_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, float *, const int*);
void csyr2k_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void cher2k_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void ctrmm_(char*, char*, char*, char*, const int*, const int*, const float *, const float *, const int*, float *, const int*);
void ctrsm_(char*, char*, char*, char*, const int*, const int*, const float *, const float *, const int*, float *, const int*);
/* Double Complex Precision */
void zgemm_(char*, char*, const int*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void zsymm_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void zhemm_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void zsyrk_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, double *, const int*);
void zherk_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, double *, const int*);
void zsyr2k_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void zher2k_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void ztrmm_(char*, char*, char*, char*, const int*, const int*, const double *, const double *, const int*, double *, const int*);
void ztrsm_(char*, char*, char*, char*, const int*, const int*, const double *, const double *, const int*, double *, const int*);
}
"""
def
_constant
(
f
):
def
_constant
(
f
):
"""Return a function that always returns its first call value
"""Return a function that always returns its first call value
"""
"""
...
@@ -603,12 +808,22 @@ def ldflags():
...
@@ -603,12 +808,22 @@ def ldflags():
"""Return a list of libraries against which an Op's object file should be
"""Return a list of libraries against which an Op's object file should be
linked to benefit from a BLAS implementation.
linked to benefit from a BLAS implementation.
Default: ['cblas','blas'], but environment variable OMEGA_BLAS_LDFLAGS overrides this.
Default: ['blas'], but environment variable THEANO_BLAS_LDFLAGS overrides this.
"""
"""
if
os
.
getenv
(
'OMEGA_BLAS_LDFLAGS'
):
if
os
.
getenv
(
'THEANO_BLAS_LDFLAGS'
):
return
os
.
getenv
(
'OMEGA_BLAS_LDFLAGS'
)
.
split
()
tokens
=
os
.
getenv
(
'THEANO_BLAS_LDFLAGS'
)
.
split
()
for
t
in
tokens
:
try
:
t0
,
t1
,
t2
=
t
[
0
:
3
]
assert
t0
==
'-'
except
e
:
raise
ValueError
(
'invalid token in THEANO_BLAS_LDFLAGS'
,
t
)
if
t1
==
'L'
:
raise
ValueError
(
'library dir not allowed in THEANO_BLAS_LDFLAGS'
,
t
)
rval
=
[
token
[
2
:]
for
token
in
tokens
]
return
rval
else
:
else
:
return
[
'
cblas'
,
'
blas'
]
return
[
'blas'
]
def
gemm_code
(
check_ab
,
a_init
,
b_init
):
def
gemm_code
(
check_ab
,
a_init
,
b_init
):
mod
=
'
%
'
mod
=
'
%
'
...
...
elemwise.py
浏览文件 @
13db39ca
...
@@ -397,9 +397,6 @@ class CAReduce(Op):
...
@@ -397,9 +397,6 @@ class CAReduce(Op):
if
dimensions_to_reduce
is
None
:
if
dimensions_to_reduce
is
None
:
dimensions_to_reduce
=
range
(
len
(
inputs
[
0
]
.
broadcastable
))
dimensions_to_reduce
=
range
(
len
(
inputs
[
0
]
.
broadcastable
))
self
.
nin
=
1
self
.
nout
=
1
self
.
inputs
=
inputs
self
.
inputs
=
inputs
self
.
outputs
=
[
Tensor
(
dtype
=
inputs
[
0
]
.
dtype
,
self
.
outputs
=
[
Tensor
(
dtype
=
inputs
[
0
]
.
dtype
,
broadcastable
=
[
x
for
i
,
x
in
enumerate
(
inputs
[
0
]
.
broadcastable
)
if
i
not
in
dimensions_to_reduce
])]
broadcastable
=
[
x
for
i
,
x
in
enumerate
(
inputs
[
0
]
.
broadcastable
)
if
i
not
in
dimensions_to_reduce
])]
...
...
gof/cc.py
浏览文件 @
13db39ca
...
@@ -409,7 +409,7 @@ class CLinker(Linker):
...
@@ -409,7 +409,7 @@ class CLinker(Linker):
elif
result
in
self
.
orphans
:
elif
result
in
self
.
orphans
:
self
.
orphans
.
remove
(
result
)
self
.
orphans
.
remove
(
result
)
continue
continue
except
AbstractFunctionError
:
except
(
AbstractFunctionError
,
NotImplementedError
)
:
pass
pass
# policy = [[what to declare in the struct, what to do at construction, what to do at destruction],
# policy = [[what to declare in the struct, what to do at construction, what to do at destruction],
# [what to declare in each run, what to do at the beginning of each run, what to do at the end of each run]]
# [what to declare in each run, what to do at the beginning of each run, what to do at the end of each run]]
...
...
gof/result.py
浏览文件 @
13db39ca
...
@@ -35,22 +35,23 @@ class Computed : """Memory has been allocated, contents are the owner's output."
...
@@ -35,22 +35,23 @@ class Computed : """Memory has been allocated, contents are the owner's output."
############################
############################
class
Result
(
object
):
class
Result
(
object
):
"""Base class for storing L{Op} inputs and outputs
"""
Base class for storing L{Op} inputs and outputs
Attributes:
Attributes:
_role - None or (owner, index) #or BrokenLink
-
_role - None or (owner, index) #or BrokenLink
_data - anything
-
_data - anything
state - one of (Empty, Allocated, Computed)
-
state - one of (Empty, Allocated, Computed)
name - string
-
name - string
Properties:
Properties:
role - (rw)
-
role - (rw)
owner - (ro)
-
owner - (ro)
index - (ro)
-
index - (ro)
data - (rw) : calls data_filter when setting
-
data - (rw) : calls data_filter when setting
Abstract Methods:
Abstract Methods:
data_filter
-
data_filter
"""
"""
__slots__
=
[
'_role'
,
'_data'
,
'state'
,
'_name'
,
'_hash_id'
]
__slots__
=
[
'_role'
,
'_data'
,
'state'
,
'_name'
,
'_hash_id'
]
...
@@ -241,6 +242,13 @@ class Result(object):
...
@@ -241,6 +242,13 @@ class Result(object):
def
c_libraries
(
self
):
def
c_libraries
(
self
):
"""
"""
Return a list of libraries to link against to manipulate this L{Result}.
Return a list of libraries to link against to manipulate this L{Result}.
For example: return ['gsl', 'gslcblas', 'm', 'fftw3', 'g2c'].
The compiler will search the directories specified by the environment
variable LD_LIBRARY_PATH. No option is provided for an Op to provide an
extra library directory because this would change the linking path for
other Ops in a potentially disasterous way.
"""
"""
raise
AbstractFunctionError
()
raise
AbstractFunctionError
()
...
...
opt.py
浏览文件 @
13db39ca
from
gof
import
opt
from
gof
import
opt
,
Env
from
elemwise
import
Broadcast
import
gof
from
elemwise
import
Broadcast
,
DimShuffle
from
gof.python25
import
any
,
all
import
scalar
class
InplaceOptimizer
(
opt
.
OpSpecificOptimizer
):
class
InplaceOptimizer
(
opt
.
OpSpecificOptimizer
):
...
@@ -26,30 +29,236 @@ class InplaceOptimizer(opt.OpSpecificOptimizer):
...
@@ -26,30 +29,236 @@ class InplaceOptimizer(opt.OpSpecificOptimizer):
inplace_optimizer
=
InplaceOptimizer
()
inplace_optimizer
=
InplaceOptimizer
()
# class ElemwisePatternOptimizer(opt.Optimizer):
# def __init__(self, scalar_opt):
class
DimShuffleLifter
(
opt
.
Optimizer
):
# self.
"""
"Lifts" DimShuffle through Broadcast operations and merges
consecutive DimShuffles. Basically, applies the following
# def find_elemwise_cliques(env, cross_broadcast = False):
transformations on the whole graph:
DimShuffle(Broadcast(x, y)) => Broadcast(DimShuffle(x), DimShuffle(y))
DimShuffle(DimShuffle(x)) => DimShuffle(x)
After this transform, clusters of Broadcast operations are
void of DimShuffle operations.
# def synchronize(env1, env2, equiv, transform):
"""
def
apply
(
self
,
env
):
seen
=
set
()
def
merge
(
ord1
,
ord2
):
return
[
x
==
'x'
and
'x'
or
ord1
[
x
]
for
x
in
ord2
]
def
lift
(
r
):
if
r
in
seen
:
return
seen
.
add
(
r
)
op
=
r
.
owner
if
op
is
None
\
or
op
in
env
.
inputs
\
or
op
in
env
.
orphans
():
return
if
isinstance
(
op
,
DimShuffle
):
in_op
=
op
.
inputs
[
0
]
.
owner
if
isinstance
(
in_op
,
DimShuffle
):
new_order
=
[
x
==
'x'
and
'x'
or
in_op
.
new_order
[
x
]
for
x
in
op
.
new_order
]
if
new_order
==
range
(
len
(
new_order
)):
repl
=
in_op
.
inputs
[
0
]
else
:
repl
=
DimShuffle
(
in_op
.
inputs
[
0
],
new_order
)
.
out
env
.
replace
(
r
,
repl
)
lift
(
repl
)
return
elif
isinstance
(
in_op
,
Broadcast
):
repl
=
Broadcast
(
in_op
.
scalar_opclass
,
[
DimShuffle
(
input
,
op
.
new_order
)
.
out
for
input
in
in_op
.
inputs
],
in_op
.
inplace_pattern
)
.
out
env
.
replace
(
r
,
repl
)
r
=
repl
op
=
r
.
owner
for
next_r
in
op
.
inputs
:
lift
(
next_r
)
for
output
in
env
.
outputs
:
lift
(
output
)
lift_dimshuffle
=
DimShuffleLifter
()
def
find_cliques
(
env
,
through_broadcast
=
False
):
def
seek_from
(
r
):
op
=
r
.
owner
if
r
in
env
.
inputs
\
or
r
in
env
.
orphans
()
\
or
op
is
None
\
or
not
isinstance
(
op
,
Broadcast
)
\
or
len
(
op
.
outputs
)
>
1
:
# todo: handle multiple-output broadcast ops
# (needs to update the clique's outputs)
return
None
ret
=
set
()
if
not
through_broadcast
:
if
any
(
any
(
bc
)
and
not
all
(
bc
)
for
bc
in
zip
(
*
[
input
.
broadcastable
for
input
in
op
.
inputs
])):
ret
.
update
(
op
.
inputs
)
return
ret
for
input
in
op
.
inputs
:
res
=
seek_from
(
input
)
if
res
is
None
:
ret
.
add
(
input
)
else
:
ret
.
update
(
res
)
return
ret
cliques
=
[]
def
find_cliques_helper
(
r
):
if
r
in
env
.
inputs
or
r
in
env
.
orphans
():
return
clique_inputs
=
seek_from
(
r
)
if
clique_inputs
is
None
:
op
=
r
.
owner
if
op
is
not
None
:
for
input
in
op
.
inputs
:
find_cliques_helper
(
input
)
else
:
cliques
.
append
((
clique_inputs
,
[
r
]))
for
input
in
clique_inputs
:
find_cliques_helper
(
input
)
for
output
in
env
.
outputs
:
find_cliques_helper
(
output
)
# todo: merge the cliques if possible
return
cliques
class
CliqueOptimizer
(
opt
.
Optimizer
):
def
__init__
(
self
,
through_broadcast
=
False
,
scalar_optimizer
=
None
,
make_composite
=
False
):
self
.
through_broadcast
=
through_broadcast
self
.
scalar_optimizer
=
scalar_optimizer
self
.
make_composite
=
make_composite
def
apply
(
self
,
env
):
if
self
.
scalar_optimizer
is
None
and
not
self
.
make_composite
:
# there's nothing to do with the cliques...
return
cliques
=
find_cliques
(
env
,
self
.
through_broadcast
)
opt
=
self
.
scalar_optimizer
def
build_scalar_clique
(
r
,
env
,
equiv
):
if
r
in
equiv
:
return
equiv
[
r
]
op
=
r
.
owner
if
r
in
env
.
inputs
or
r
in
env
.
orphans
():
s
=
scalar
.
Scalar
(
dtype
=
r
.
dtype
)
_r
=
r
if
isinstance
(
r
.
owner
,
DimShuffle
)
and
all
(
x
==
'x'
for
x
in
r
.
owner
.
new_order
):
_r
=
r
.
owner
.
inputs
[
0
]
if
(
getattr
(
r
,
'constant'
,
False
)
or
getattr
(
_r
,
'constant'
,
False
))
\
and
_r
.
broadcastable
==
():
s
.
data
=
_r
.
data
s
.
constant
=
True
equiv
[
r
]
=
s
return
s
s_op
=
op
.
scalar_opclass
(
*
[
build_scalar_clique
(
input
,
env
,
equiv
)
for
input
in
op
.
inputs
])
equiv
[
op
]
=
s_op
for
output
,
s_output
in
zip
(
op
.
outputs
,
s_op
.
outputs
):
equiv
[
output
]
=
s_output
return
equiv
[
r
]
for
c_in
,
c_out
in
cliques
:
equiv
=
dict
()
g
=
Env
(
c_in
,
c_out
)
for
output
in
c_out
:
build_scalar_clique
(
output
,
g
,
equiv
)
s_g
=
Env
([
equiv
[
r
]
for
r
in
g
.
inputs
],
[
equiv
[
r
]
for
r
in
g
.
outputs
])
if
opt
is
not
None
:
equiv2
=
dict
()
for
k
,
v
in
equiv
.
items
():
equiv2
[
v
]
=
k
def
transform
(
op
,
equiv
):
return
Broadcast
(
op
.
__class__
,
[
equiv
[
input
]
for
input
in
op
.
inputs
])
s_g
.
add_feature
(
sync_to
(
env
,
equiv2
,
transform
))
opt
.
optimize
(
s_g
)
if
self
.
make_composite
:
def
follow_inplace
(
r
):
op
=
r
.
owner
if
op
is
None
or
r
in
g
.
inputs
or
r
in
g
.
orphans
():
return
None
assert
isinstance
(
op
,
Broadcast
)
destroyed
=
op
.
destroy_map
()
.
get
(
r
,
None
)
if
destroyed
is
None
:
return
None
else
:
r2
=
destroyed
[
0
]
ret
=
follow_inplace
(
r2
)
if
ret
is
None
:
return
r2
else
:
return
ret
inplace_pattern
=
{}
for
i
,
output
in
enumerate
(
g
.
outputs
):
destroyed
=
follow_inplace
(
output
)
if
destroyed
is
not
None
and
destroyed
in
g
.
inputs
:
inplace_pattern
[
i
]
=
g
.
inputs
.
index
(
destroyed
)
C
=
scalar
.
composite
(
s_g
.
inputs
,
s_g
.
outputs
)
ec
=
Broadcast
(
C
,
g
.
inputs
,
inplace_pattern
=
inplace_pattern
)
env
.
replace_all
(
dict
((
o
,
eco
)
for
o
,
eco
in
zip
(
c_out
,
ec
.
outputs
)))
def
sync_to
(
target
,
equiv
,
transform
):
class
Synchronize
(
gof
.
Listener
,
gof
.
Constraint
):
def
__init__
(
self
,
source
):
self
.
source
=
source
self
.
target
=
target
self
.
equiv
=
equiv
self
.
transform
=
transform
self
.
inconsistencies
=
[]
def
on_import
(
self
,
op1
):
if
op1
not
in
self
.
equiv
:
op2
=
self
.
transform
(
op1
,
self
.
equiv
)
self
.
equiv
[
op1
]
=
op2
for
o1
,
o2
in
zip
(
op1
.
outputs
,
op2
.
outputs
):
self
.
equiv
[
o1
]
=
o2
def
on_prune
(
self
,
op1
):
if
op1
in
self
.
equiv
:
op2
=
self
.
equiv
[
op1
]
del
self
.
equiv
[
op1
]
for
o1
,
o2
in
zip
(
op1
.
outputs
,
op2
.
outputs
):
del
self
.
equiv
[
o1
]
def
on_rewire
(
self
,
clients1
,
r1
,
new_r1
):
if
(
new_r1
,
r1
)
in
self
.
inconsistencies
:
self
.
inconsistencies
.
remove
((
new_r1
,
r1
))
return
if
not
self
.
source
.
clients
(
r1
):
try
:
target
.
replace
(
self
.
equiv
[
r1
],
self
.
equiv
[
new_r1
])
except
:
self
.
inconsistencies
.
append
((
r1
,
new_r1
))
# class Synchronize(Listener, Constraint):
def
validate
(
self
):
if
self
.
inconsistencies
:
raise
InconsistencyError
(
"Could not synchronize when replacing the following pairs:
%
s"
%
self
.
inconsistencies
)
return
True
# def on_import(self, op1):
return
Synchronize
# if op1 not in equiv:
# equiv[op1] = transform(op1)
# def on_prune(self, op1):
# if op1 in equiv:
# del equiv[op1]
...
...
scalar.py
浏览文件 @
13db39ca
...
@@ -5,7 +5,8 @@ import math
...
@@ -5,7 +5,8 @@ import math
from
copy
import
copy
from
copy
import
copy
import
inspect
import
inspect
from
gof
import
Result
,
GuardedOp
,
utils
import
gof
from
gof
import
Result
,
GuardedOp
,
Env
,
utils
def
as_scalar
(
x
,
name
=
None
):
def
as_scalar
(
x
,
name
=
None
):
...
@@ -20,6 +21,11 @@ def as_scalar(x, name = None):
...
@@ -20,6 +21,11 @@ def as_scalar(x, name = None):
if
isinstance
(
x
,
Scalar
):
if
isinstance
(
x
,
Scalar
):
return
x
return
x
def
constant
(
x
):
res
=
as_scalar
(
x
)
res
.
constant
=
True
return
res
class
Scalar
(
Result
):
class
Scalar
(
Result
):
...
@@ -29,6 +35,8 @@ class Scalar(Result):
...
@@ -29,6 +35,8 @@ class Scalar(Result):
self
.
dtype_specs
()
self
.
dtype_specs
()
def
__get_constant
(
self
):
def
__get_constant
(
self
):
if
not
hasattr
(
self
,
'_constant'
):
return
False
return
self
.
_constant
return
self
.
_constant
def
__set_constant
(
self
,
value
):
def
__set_constant
(
self
,
value
):
...
@@ -38,6 +46,9 @@ class Scalar(Result):
...
@@ -38,6 +46,9 @@ class Scalar(Result):
constant
=
property
(
__get_constant
,
__set_constant
)
constant
=
property
(
__get_constant
,
__set_constant
)
def
desc
(
self
):
return
(
self
.
dtype
,
self
.
data
)
def
filter
(
self
,
data
):
def
filter
(
self
,
data
):
py_type
=
self
.
dtype_specs
()[
0
]
py_type
=
self
.
dtype_specs
()[
0
]
return
py_type
(
data
)
return
py_type
(
data
)
...
@@ -58,6 +69,11 @@ class Scalar(Result):
...
@@ -58,6 +69,11 @@ class Scalar(Result):
except
KeyError
:
except
KeyError
:
raise
TypeError
(
"Unsupported dtype for
%
s:
%
s"
%
(
self
.
__class__
.
__name__
,
self
.
dtype
))
raise
TypeError
(
"Unsupported dtype for
%
s:
%
s"
%
(
self
.
__class__
.
__name__
,
self
.
dtype
))
def
c_literal
(
self
):
if
'complex'
in
self
.
dtype
:
raise
NotImplementedError
(
"No literal for complex values."
)
return
str
(
self
.
data
)
def
c_declare
(
self
,
name
,
sub
):
def
c_declare
(
self
,
name
,
sub
):
return
"""
return
"""
%(dtype)
s
%(name)
s;
%(dtype)
s
%(name)
s;
...
@@ -184,7 +200,7 @@ class ScalarMixedOp(GuardedOp):
...
@@ -184,7 +200,7 @@ class ScalarMixedOp(GuardedOp):
inputs
=
[
as_scalar
(
input
)
for
input
in
inputs
]
inputs
=
[
as_scalar
(
input
)
for
input
in
inputs
]
i_dtypes
=
[
getattr
(
input
,
'dtype'
,
None
)
for
input
in
inputs
]
i_dtypes
=
[
getattr
(
input
,
'dtype'
,
None
)
for
input
in
inputs
]
o_dtypes
=
utils
.
from_return_values
(
self
.
propagate_dtypes
(
*
i_dtypes
)
)
o_dtypes
=
self
.
propagate_dtypes
(
*
i_dtypes
)
self
.
inputs
=
inputs
self
.
inputs
=
inputs
self
.
outputs
=
[
Scalar
(
dtype
)
for
dtype
in
o_dtypes
]
self
.
outputs
=
[
Scalar
(
dtype
)
for
dtype
in
o_dtypes
]
...
@@ -217,7 +233,7 @@ class PureScalarOp(ScalarMixedOp):
...
@@ -217,7 +233,7 @@ class PureScalarOp(ScalarMixedOp):
for
dtype
in
i_dtypes
:
for
dtype
in
i_dtypes
:
if
dtype
is
None
:
if
dtype
is
None
:
raise
TypeError
(
"Expected a Scalar."
)
raise
TypeError
(
"Expected a Scalar."
)
return
self
.
cast_method
(
*
i_dtypes
)
return
[
self
.
cast_method
(
*
i_dtypes
)]
*
self
.
nout
class
UnaryScalarOp
(
PureScalarOp
):
class
UnaryScalarOp
(
PureScalarOp
):
...
@@ -383,4 +399,111 @@ modes.make_constructors(globals())
...
@@ -383,4 +399,111 @@ modes.make_constructors(globals())
def
composite
(
inputs
,
outputs
):
"""
Usage: composite(inputs, outputs)
Produces an Op class which represents the computations
between the provided inputs and outputs as a single
operation.
The operations between inputs and outputs (as given by
Env(inputs, outputs).ops()) must all be instances of
PureScalarOp.
Examples:
x, y = Scalar(), Scalar()
SquareDiff = composite([x, y], [(x - y)**2])
TimesTen = composite([x], [x * 10.0])
Neighbors = composite([x], [x - 1, x + 1])
"""
env
=
Env
(
inputs
,
outputs
)
.
clone
()
gof
.
opt
.
ConstantFinder
()
.
apply
(
env
)
inputs
,
outputs
=
env
.
inputs
,
env
.
outputs
for
op
in
env
.
ops
():
if
not
isinstance
(
op
,
PureScalarOp
):
raise
ValueError
(
"The input env to composite must be exclusively composed of PureScalarOp instances."
)
subd
=
dict
(
zip
(
inputs
,
[
"
%%
(i
%
i)s"
%
i
for
i
in
range
(
len
(
inputs
))])
+
zip
(
outputs
,
[
"
%%
(o
%
i)s"
%
i
for
i
in
range
(
len
(
outputs
))]))
for
orphan
in
env
.
orphans
():
if
orphan
.
constant
:
subd
[
orphan
]
=
orphan
.
c_literal
()
else
:
raise
ValueError
(
"All orphans in the input env to composite must be constant."
)
_c_code
=
"{
\n
"
i
=
0
j
=
0
for
op
in
env
.
toposort
():
j
+=
1
for
output
in
op
.
outputs
:
if
output
not
in
subd
:
i
+=
1
name
=
"V
%%(id)
s_tmp
%
i"
%
i
subd
[
output
]
=
name
# the c code is not robust to any other dtypes than those of the specified inputs
# a solution would be to require Composite.c_code to fill in the dtypes using
# a proper upcast
_c_code
+=
"
%
s
%
s;
\n
"
%
(
output
.
dtype_specs
()[
1
],
name
)
_c_code
+=
op
.
c_code
([
subd
[
input
]
for
input
in
op
.
inputs
],
[
subd
[
output
]
for
output
in
op
.
outputs
],
dict
(
fail
=
"
%(fail)
s"
,
id
=
"
%%(id)
s_
%
i"
%
j
))
_c_code
+=
"
\n
"
_c_code
+=
"}
\n
"
def
compose_impl
(
r
):
# this is not optimal at all eg in add(*1 -> mul(x, y), *1)
# it will calculate *1 twice
# it also doesn't follow env.toposort but that's (presumably)
# still correct since we only have pure scalar ops
if
r
in
env
.
inputs
:
idx
=
env
.
inputs
.
index
(
r
)
return
lambda
inputs
:
inputs
[
idx
]
elif
r
in
env
.
orphans
():
return
lambda
inputs
:
r
.
data
op
=
r
.
owner
producers
=
[
compose_impl
(
input
)
for
input
in
op
.
inputs
]
return
lambda
inputs
:
op
.
impl
(
*
[
p
(
inputs
)
for
p
in
producers
])
_impls
=
[
compose_impl
(
r
)
for
r
in
env
.
outputs
]
class
Composite
(
PureScalarOp
):
nin
=
len
(
inputs
)
nout
=
len
(
outputs
)
# todo: propagate_dtypes?
def
perform
(
self
):
inputs
=
[
input
.
data
for
input
in
self
.
inputs
]
for
output
,
impl
in
zip
(
self
.
outputs
,
_impls
):
output
.
data
=
impl
(
inputs
)
def
impl
(
self
,
*
inputs
):
for
r
,
input
in
zip
(
self
.
inputs
,
inputs
):
r
.
data
=
input
self
.
perform
()
return
utils
.
to_return_values
([
output
.
data
for
output
in
self
.
outputs
])
def
grad
(
self
,
inputs
,
output_grads
):
raise
NotImplementedError
(
"grad is not implemented for Composite"
)
def
c_code
(
self
,
inames
,
onames
,
sub
):
d
=
dict
(
zip
([
"i
%
i"
%
i
for
i
in
range
(
len
(
inames
))],
inames
)
+
zip
([
"o
%
i"
%
i
for
i
in
range
(
len
(
onames
))],
onames
),
**
sub
)
return
_c_code
%
d
return
Composite
scalar_opt.py
0 → 100644
浏览文件 @
13db39ca
from
scalar
import
*
from
gof
import
PatternOptimizer
c2
=
constant
(
2.0
)
opt1
=
PatternOptimizer
((
Mul
,
'x'
,
'x'
),
(
Sqr
,
'x'
))
opt2
=
PatternOptimizer
((
Pow
,
'x'
,
c2
),
(
Sqr
,
'x'
))
tensor.py
浏览文件 @
13db39ca
...
@@ -153,7 +153,7 @@ class _Op(BaseTensorOp):
...
@@ -153,7 +153,7 @@ class _Op(BaseTensorOp):
return
self
.
c_impl
(
self
.
inputs
,
self
.
outputs
)
%
sub
return
self
.
c_impl
(
self
.
inputs
,
self
.
outputs
)
%
sub
def
c_impl
(
self
,
inputs
,
outputs
):
def
c_impl
(
self
,
inputs
,
outputs
):
raise
AbstractFunctionError
()
raise
AbstractFunctionError
(
"No c_impl for
%
s"
%
self
.
__class__
.
__name__
)
class
_Unary
:
class
_Unary
:
nin
=
1
nin
=
1
...
@@ -420,24 +420,22 @@ class Gemm(_Op):
...
@@ -420,24 +420,22 @@ class Gemm(_Op):
raise
NotImplementedError
()
raise
NotImplementedError
()
def
c_support_code
(
self
):
def
c_support_code
(
self
):
return
blas
.
cblas_header_text
()
#return blas.cblas_header_text()
mod_str
=
"""
#ifndef MOD
#define MOD
%
#endif
"""
return
blas
.
blas_proto
()
+
mod_str
def
c_headers
(
self
):
return
[
'<iostream>'
]
def
c_libraries
(
self
):
def
c_libraries
(
self
):
return
blas
.
ldflags
()
return
blas
.
ldflags
()
def
c_var_names
(
self
):
#def c_var_names(self):
return
[[
'_z'
,
'_a'
,
'_x'
,
'_y'
,
'_b'
],
[
'_zout'
]]
# return [['_z', '_a', '_x', '_y', '_b'], ['_zout']]
def
c_validate_update
(
self
,
(
_z
,
_a
,
_x
,
_y
,
_b
),
(
_zout
,
),
sub
):
def
c_validate_update
(
self
,
*
args
):
return
"""
return
""
if (
%(_zout)
s !=
%(_z)
s)
def
c_validate_update_cleanup
(
self
,
*
args
):
{
if (
%(_zout)
s)
{
Py_DECREF(
%(_zout)
s);
}
%(_zout)
s =
%(_z)
s;
Py_INCREF(
%(_zout)
s);
}
"""
%
locals
()
def
c_validate_update_cleanup
(
self
,
ignore
,
_ignore
,
__ignore
):
return
""
return
""
def
c_code
(
self
,
(
_z
,
_a
,
_x
,
_y
,
_b
),
(
_zout
,
),
sub
):
def
c_code
(
self
,
(
_z
,
_a
,
_x
,
_y
,
_b
),
(
_zout
,
),
sub
):
return
"""
return
"""
...
@@ -454,14 +452,22 @@ class Gemm(_Op):
...
@@ -454,14 +452,22 @@ class Gemm(_Op):
npy_intp* Sy =
%(_y)
s->strides;
npy_intp* Sy =
%(_y)
s->strides;
npy_intp* Sz =
%(_z)
s->strides;
npy_intp* Sz =
%(_z)
s->strides;
size_t sx_0, sx_1, sy_0, sy_1, sz_0, sz_1;
//strides for x, y, z in dimensions 0, 1
int sx_0, sx_1, sy_0, sy_1, sz_0, sz_1;
if (
%(_zout)
s !=
%(_z)
s)
{
if (
%(_zout)
s)
{
Py_DECREF(
%(_zout)
s);
}
%(_zout)
s =
%(_z)
s;
Py_INCREF(
%(_zout)
s);
}
if (
%(_x)
s->nd != 2)
if (
%(_x)
s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(x) != 2");
%(fail)
s;}
{PyErr_SetString(PyExc_NotImplementedError, "rank(x) != 2");
%(fail)
s;}
if (
%(_y)
s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(y) != 2");
%(fail)
s;}
if (
%(_y)
s->nd != 2)
if (
%(_z)
s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(z) != 2");
%(fail)
s;}
{PyErr_SetString(PyExc_NotImplementedError, "rank(y) != 2");
%(fail)
s;}
if (
%(_z)
s->nd != 2)
{PyErr_SetString(PyExc_NotImplementedError, "rank(z) != 2");
%(fail)
s;}
if ((
%(_a)
s->descr->type_num != PyArray_DOUBLE)
if ((
%(_a)
s->descr->type_num != PyArray_DOUBLE)
&& (
%(_a)
s->descr->type_num != PyArray_FLOAT))
&& (
%(_a)
s->descr->type_num != PyArray_FLOAT))
...
@@ -473,19 +479,19 @@ class Gemm(_Op):
...
@@ -473,19 +479,19 @@ class Gemm(_Op):
if ((
%(_x)
s->descr->type_num != PyArray_DOUBLE)
if ((
%(_x)
s->descr->type_num != PyArray_DOUBLE)
&& (
%(_x)
s->descr->type_num != PyArray_FLOAT))
&& (
%(_x)
s->descr->type_num != PyArray_FLOAT))
%(fail)
s;
{PyErr_SetString(PyExc_NotImplementedError, "type(x) is not double or float");
%(fail)
s;}
if ((
%(_y)
s->descr->type_num != PyArray_DOUBLE)
if ((
%(_y)
s->descr->type_num != PyArray_DOUBLE)
&& (
%(_y)
s->descr->type_num != PyArray_FLOAT))
&& (
%(_y)
s->descr->type_num != PyArray_FLOAT))
%(fail)
s;
{PyErr_SetString(PyExc_NotImplementedError, "type(y) is not double or float");
%(fail)
s;}
if ((
%(_
y
)
s->descr->type_num != PyArray_DOUBLE)
if ((
%(_
z
)
s->descr->type_num != PyArray_DOUBLE)
&& (
%(_
y
)
s->descr->type_num != PyArray_FLOAT))
&& (
%(_
z
)
s->descr->type_num != PyArray_FLOAT))
%(fail)
s;
{PyErr_SetString(PyExc_NotImplementedError, "type(z) is not double or float");
%(fail)
s;}
if ((
%(_x)
s->descr->type_num !=
%(_y)
s->descr->type_num)
if ((
%(_x)
s->descr->type_num !=
%(_y)
s->descr->type_num)
||(
%(_x)
s->descr->type_num !=
%(_z)
s->descr->type_num))
||(
%(_x)
s->descr->type_num !=
%(_z)
s->descr->type_num))
%(fail)
s;
{ PyErr_SetString(PyExc_NotImplementedError, "type(z), type(y), type(z) are not all the same");
%(fail)
s; }
if ((Nx[0] != Nz[0]) || (Nx[1] != Ny[0]) || (Ny[1] != Nz[1]))
if ((Nx[0] != Nz[0]) || (Nx[1] != Ny[0]) || (Ny[1] != Nz[1]))
{
{
...
@@ -496,17 +502,15 @@ class Gemm(_Op):
...
@@ -496,17 +502,15 @@ class Gemm(_Op):
|| (Sy[0] < 1) || (Sy[1] < 1) || (Sy[0] MOD type_size) || (Sy[1] MOD type_size)
|| (Sy[0] < 1) || (Sy[1] < 1) || (Sy[0] MOD type_size) || (Sy[1] MOD type_size)
|| (Sz[0] < 1) || (Sz[1] < 1) || (Sz[0] MOD type_size) || (Sz[1] MOD type_size))
|| (Sz[0] < 1) || (Sz[1] < 1) || (Sz[0] MOD type_size) || (Sz[1] MOD type_size))
{
{
PyErr_SetString(PyExc_ValueError, "gemm cant run on these inputs");
PyErr_SetString(PyExc_ValueError, "stride is not multiple of element size");
%(fail)
s;
%(fail)
s;
}
}
/*
/*
encode the stride structure of _x,_y,_z into a single integer
encode the stride structure of _x,_y,_z into a single integer
*/
*/
unit |= ((Sx[1] == type_size) ? 0x0 : (Sx[0] == type_size) ? 0x1 : 0x2) <<
0
;
unit |= ((Sx[1] == type_size) ? 0x0 : (Sx[0] == type_size) ? 0x1 : 0x2) <<
8
;
unit |= ((Sy[1] == type_size) ? 0x0 : (Sy[0] == type_size) ? 0x1 : 0x2) << 4;
unit |= ((Sy[1] == type_size) ? 0x0 : (Sy[0] == type_size) ? 0x1 : 0x2) << 4;
unit |= ((Sz[1] == type_size) ? 0x0 : (Sz[0] == type_size) ? 0x1 : 0x2) <<
8
;
unit |= ((Sz[1] == type_size) ? 0x0 : (Sz[0] == type_size) ? 0x1 : 0x2) <<
0
;
/* create appropriate strides for malformed matrices that are row or column
/* create appropriate strides for malformed matrices that are row or column
* vectors
* vectors
...
@@ -533,18 +537,21 @@ class Gemm(_Op):
...
@@ -533,18 +537,21 @@ class Gemm(_Op):
float* x = (float*)PyArray_DATA(
%(_x)
s);
float* x = (float*)PyArray_DATA(
%(_x)
s);
float* y = (float*)PyArray_DATA(
%(_y)
s);
float* y = (float*)PyArray_DATA(
%(_y)
s);
float* z = (float*)PyArray_DATA(
%(_z)
s);
float* z = (float*)PyArray_DATA(
%(_z)
s);
char N = 'N';
char T = 'T';
int Nz0 = Nz[0], Nz1 = Nz[1], Nx1 = Nx[1];
std::cerr << (unit/256) MOD 16 << (unit / 16) MOD 16 << unit MOD 16<< '
\\
n';
switch(unit)
switch(unit)
{
{
case 0x000:
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z,
sz_0); break;
case 0x000:
sgemm_(&N, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_0, &b, z, &
sz_0); break;
case 0x
001: cblas_sgemm(CblasRowMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z,
sz_0); break;
case 0x
100: sgemm_(&N, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_1, &b, z, &
sz_0); break;
case 0x010:
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z,
sz_0); break;
case 0x010:
sgemm_(&T, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_0, &b, z, &
sz_0); break;
case 0x
011: cblas_sgemm(CblasRowMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z,
sz_0); break;
case 0x
110: sgemm_(&T, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_1, &b, z, &
sz_0); break;
case 0x
100: cblas_sgemm(CblasColMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z,
sz_1); break;
case 0x
001: sgemm_(&T, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_0, &b, z, &
sz_1); break;
case 0x101:
cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z,
sz_1); break;
case 0x101:
sgemm_(&N, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_0, &b, z, &
sz_1); break;
case 0x
110: cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z,
sz_1); break;
case 0x
011: sgemm_(&T, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_1, &b, z, &
sz_1); break;
case 0x111:
cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z,
sz_1); break;
case 0x111:
sgemm_(&N, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_1, &b, z, &
sz_1); break;
default:
%(fail)
s;
default:
PyErr_SetString(PyExc_ValueError, "some matrix has no unit stride");
%(fail)
s;
};
};
#undef REAL
#undef REAL
}
}
...
@@ -562,17 +569,21 @@ class Gemm(_Op):
...
@@ -562,17 +569,21 @@ class Gemm(_Op):
double* x = (double*)PyArray_DATA(
%(_x)
s);
double* x = (double*)PyArray_DATA(
%(_x)
s);
double* y = (double*)PyArray_DATA(
%(_y)
s);
double* y = (double*)PyArray_DATA(
%(_y)
s);
double* z = (double*)PyArray_DATA(
%(_z)
s);
double* z = (double*)PyArray_DATA(
%(_z)
s);
char N = 'N';
char T = 'T';
int Nz0 = Nz[0], Nz1 = Nz[1], Nx1 = Nx[1];
//std::cerr << (unit/256) MOD 16 << (unit / 16) MOD 16 << unit MOD 16<< '
\\
n';
switch(unit)
switch(unit)
{
{
case 0x000:
cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z,
sz_0); break;
case 0x000:
dgemm_(&N, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_0, &b, z, &
sz_0); break;
case 0x
001: cblas_dgemm(CblasRowMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z,
sz_0); break;
case 0x
100: dgemm_(&N, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_1, &b, z, &
sz_0); break;
case 0x010:
cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z,
sz_0); break;
case 0x010:
dgemm_(&T, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_0, &b, z, &
sz_0); break;
case 0x
011: cblas_dgemm(CblasRowMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z,
sz_0); break;
case 0x
110: dgemm_(&T, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_1, &b, z, &
sz_0); break;
case 0x
100: cblas_dgemm(CblasColMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z,
sz_1); break;
case 0x
001: dgemm_(&T, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_0, &b, z, &
sz_1); break;
case 0x101:
cblas_dgemm(CblasColMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z,
sz_1); break;
case 0x101:
dgemm_(&N, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_0, &b, z, &
sz_1); break;
case 0x
110: cblas_dgemm(CblasColMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z,
sz_1); break;
case 0x
011: dgemm_(&T, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_1, &b, z, &
sz_1); break;
case 0x111:
cblas_dgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z,
sz_1); break;
case 0x111:
dgemm_(&N, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_1, &b, z, &
sz_1); break;
default:
%(fail)
s;
default:
PyErr_SetString(PyExc_ValueError, "some matrix has no unit stride");
%(fail)
s;
};
};
#undef REAL
#undef REAL
}
}
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论