Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
bf3df169
提交
bf3df169
authored
4月 17, 2008
作者:
Olivier Breuleux
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fixes here and there, doc
上级
7158e6d8
隐藏空白字符变更
内嵌
并排
正在显示
5 个修改的文件
包含
169 行增加
和
216 行删除
+169
-216
_test_scalar_opt.py
_test_scalar_opt.py
+4
-3
base_tensor.py
base_tensor.py
+5
-1
opt.py
opt.py
+82
-46
scalar_opt.py
scalar_opt.py
+73
-2
tensor.py
tensor.py
+5
-164
没有找到文件。
_test_scalar_opt.py
浏览文件 @
bf3df169
...
...
@@ -42,7 +42,7 @@ class _test_opts(unittest.TestCase):
# x, y, z = inputs()
# a, b, c, d = more_inputs()
# # e = (2.0 * x) / (2.0 * y)
#
#
e = (2.0 * x) / (4.0 * y)
# e = (2.0 * x) / (4.0 * y)
# # e = x / (y / z)
# # e = (x * y) / x
# # e = (x / y) * (y / z) * (z / x)
...
...
@@ -71,11 +71,12 @@ class _test_opts(unittest.TestCase):
# # e = (a - b) + (b - c) + (c - d)
# # e = x + -y
# # e = a - b - b + a + b + c + b - c
# e = x + log(y) - x + y
# # e = x + log(y) - x + y
# e = 2.0 + x + 4.0
# g = Env([x, y, z, a, b, c, d], [e])
# print g
# gof.ConstantFinder().optimize(g)
# addfn = lambda *inputs:
reduce(lambda x, y: x + y, (0,) +
inputs)
# addfn = lambda *inputs:
sum(
inputs)
# subfn = lambda x, y: x - y
# negfn = lambda x: -x
# Canonizer(Add, Sub, Neg, addfn, subfn, negfn).optimize(g)
...
...
base_tensor.py
浏览文件 @
bf3df169
...
...
@@ -58,7 +58,7 @@ class BaseTensor(Result):
# filter
#
def
filter
(
self
,
arr
):
"""
cast to an L{numpy.ndarray} and ensure arr has correct rank, shape
"""
"""
Cast to an L{numpy.ndarray} and ensure arr has correct rank and shape.
"""
if
not
(
isinstance
(
arr
,
numpy
.
ndarray
)
\
and
arr
.
dtype
==
self
.
dtype
):
arr
=
numpy
.
asarray
(
arr
,
dtype
=
self
.
dtype
)
...
...
@@ -102,6 +102,9 @@ class BaseTensor(Result):
# Description for constant folding
#
def
desc
(
self
):
"""
Returns a hashable description of this BaseTensor.
"""
if
self
.
data
is
not
None
:
return
(
BaseTensor
,
self
.
dtype
,
self
.
broadcastable
,
self
.
data
.
data
[:])
else
:
...
...
@@ -210,6 +213,7 @@ class BaseTensor(Result):
};
"""
return
template
%
dict
(
nbits
=
64
,
half_nbits
=
32
)
+
template
%
dict
(
nbits
=
128
,
half_nbits
=
64
)
# todo: use C templating
############################
...
...
opt.py
浏览文件 @
bf3df169
...
...
@@ -7,6 +7,19 @@ import scalar
class
InplaceOptimizer
(
opt
.
OpSpecificOptimizer
):
"""
Usage: inplace_optimizer.optimize(env)
Attempts to replace all Broadcast ops by versions of them
that operate inplace. It operates greedily: for each Broadcast
Op that is encountered, for each output, tries each input to
see if it can operate inplace on that input. If so, makes the
change and go to the next output or Broadcast Op.
Examples:
x + y + z -> x += y += z
(x + y) * (x * y) -> (x += y) *= (x * y) or (x + y) *= (x *= y)
"""
opclass
=
Broadcast
...
...
@@ -24,6 +37,7 @@ class InplaceOptimizer(opt.OpSpecificOptimizer):
continue
candidate_inputs
.
remove
(
candidate_input
)
op
=
new_op
baseline
=
inplace_pattern
break
inplace_optimizer
=
InplaceOptimizer
()
...
...
@@ -32,6 +46,8 @@ inplace_optimizer = InplaceOptimizer()
class
DimShuffleLifter
(
opt
.
Optimizer
):
"""
Usage: lift_dimshuffle.optimize(env)
"Lifts" DimShuffle through Broadcast operations and merges
consecutive DimShuffles. Basically, applies the following
transformations on the whole graph:
...
...
@@ -46,9 +62,6 @@ class DimShuffleLifter(opt.Optimizer):
def
apply
(
self
,
env
):
seen
=
set
()
def
merge
(
ord1
,
ord2
):
return
[
x
==
'x'
and
'x'
or
ord1
[
x
]
for
x
in
ord2
]
def
lift
(
r
):
if
r
in
seen
:
...
...
@@ -62,6 +75,7 @@ class DimShuffleLifter(opt.Optimizer):
if
isinstance
(
op
,
DimShuffle
):
in_op
=
op
.
inputs
[
0
]
.
owner
if
isinstance
(
in_op
,
DimShuffle
):
# DimShuffle(DimShuffle(x)) => DimShuffle(x)
new_order
=
[
x
==
'x'
and
'x'
or
in_op
.
new_order
[
x
]
for
x
in
op
.
new_order
]
if
new_order
==
range
(
len
(
new_order
)):
repl
=
in_op
.
inputs
[
0
]
...
...
@@ -71,6 +85,7 @@ class DimShuffleLifter(opt.Optimizer):
lift
(
repl
)
return
elif
isinstance
(
in_op
,
Broadcast
):
# DimShuffle(Broadcast(x, y)) => Broadcast(DimShuffle(x), DimShuffle(y))
repl
=
Broadcast
(
in_op
.
scalar_opclass
,
[
DimShuffle
(
input
,
op
.
new_order
)
.
out
for
input
in
in_op
.
inputs
],
in_op
.
inplace_pattern
)
.
out
...
...
@@ -87,9 +102,24 @@ lift_dimshuffle = DimShuffleLifter()
def
find_cliques
(
env
,
through_broadcast
=
False
):
"""
Usage: find_cliques(env, through_broadcast = False)
Returns a list of pairs where each pair contains a list
of inputs and a list of outputs such that Env(inputs, outputs)
contains nothing but Broadcast Ops.
If through_broadcast is False, the cliques will only be
allowed to broadcast over the inputs, which means, for
example, that vector operations will not be mixed with
matrix operations.
"""
def
seek_from
(
r
):
# walks through the graph until it encounters a
# non-Broadcast operation or (if through_broadcast
# is False) a Result which needs to be broadcasted.
op
=
r
.
owner
if
r
in
env
.
inputs
\
or
r
in
env
.
orphans
()
\
...
...
@@ -103,6 +133,10 @@ def find_cliques(env, through_broadcast = False):
ret
=
set
()
if
not
through_broadcast
:
# check each dimension over all the inputs - if the broadcastable
# fields are not all 0 or all 1 for a particular dimension, then
# broadcasting will be performed along it on the inputs where the
# value is 1 and we will stop.
if
any
(
any
(
bc
)
and
not
all
(
bc
)
for
bc
in
zip
(
*
[
input
.
broadcastable
for
input
in
op
.
inputs
])):
ret
.
update
(
op
.
inputs
)
...
...
@@ -111,6 +145,7 @@ def find_cliques(env, through_broadcast = False):
for
input
in
op
.
inputs
:
res
=
seek_from
(
input
)
if
res
is
None
:
# input is a leaf of our search
ret
.
add
(
input
)
else
:
ret
.
update
(
res
)
...
...
@@ -124,11 +159,14 @@ def find_cliques(env, through_broadcast = False):
return
clique_inputs
=
seek_from
(
r
)
if
clique_inputs
is
None
:
# Not in a clique, keep going
op
=
r
.
owner
if
op
is
not
None
:
for
input
in
op
.
inputs
:
find_cliques_helper
(
input
)
else
:
# We found a clique, add it to the list and
# jump to the leaves.
cliques
.
append
((
clique_inputs
,
[
r
]))
for
input
in
clique_inputs
:
find_cliques_helper
(
input
)
...
...
@@ -142,6 +180,24 @@ def find_cliques(env, through_broadcast = False):
class
CliqueOptimizer
(
opt
.
Optimizer
):
"""
Usage: CliqueOptimizer(through_broadcast = False,
scalar_optimizer = None,
make_composite = False).optimize(env)
Finds cliques of Broadcast operations in the env and does either
or both of two things:
* Apply scalar_optimizer on the clique as if the clique was a
group of scalar operations. scalar_optimizer can be any optimization
which applies on scalars. If it is None, no optimization is done.
* Replace the clique with a single Op, optimized to perform the
computations properly. If make_composite is False, no such replacement
is done.
Note: it is recommended to run the lift_dimshuffle optimization before
this one.
"""
def
__init__
(
self
,
through_broadcast
=
False
,
scalar_optimizer
=
None
,
make_composite
=
False
):
self
.
through_broadcast
=
through_broadcast
...
...
@@ -152,20 +208,25 @@ class CliqueOptimizer(opt.Optimizer):
if
self
.
scalar_optimizer
is
None
and
not
self
.
make_composite
:
# there's nothing to do with the cliques...
return
cliques
=
find_cliques
(
env
,
self
.
through_broadcast
)
opt
=
self
.
scalar_optimizer
def
build_scalar_clique
(
r
,
env
,
equiv
):
# Maps a clique of Broadcast Ops to a clique of Scalar Ops with the same
# structure and equivalent operations. equiv contains the mapping.
if
r
in
equiv
:
return
equiv
[
r
]
op
=
r
.
owner
if
r
in
env
.
inputs
or
r
in
env
.
orphans
():
# For each leave we make a Scalar of the corresponding dtype
s
=
scalar
.
Scalar
(
dtype
=
r
.
dtype
)
_r
=
r
if
isinstance
(
r
.
owner
,
DimShuffle
)
and
all
(
x
==
'x'
for
x
in
r
.
owner
.
new_order
):
_r
=
r
.
owner
.
inputs
[
0
]
if
(
getattr
(
r
,
'constant'
,
False
)
or
getattr
(
_r
,
'constant'
,
False
))
\
and
_r
.
broadcastable
==
():
# If we have a constant tensor we map it to a constant scalar.
s
.
data
=
_r
.
data
s
.
constant
=
True
equiv
[
r
]
=
s
...
...
@@ -184,15 +245,18 @@ class CliqueOptimizer(opt.Optimizer):
s_g
=
Env
([
equiv
[
r
]
for
r
in
g
.
inputs
],
[
equiv
[
r
]
for
r
in
g
.
outputs
])
if
opt
is
not
None
:
equiv2
=
dict
()
equiv2
=
dict
()
# reverse mapping, from Scalar Op to Tensor Op
for
k
,
v
in
equiv
.
items
():
equiv2
[
v
]
=
k
def
transform
(
op
,
equiv
):
# We get a scalar op and we return an equivalent op on tensors.
return
Broadcast
(
op
.
__class__
,
[
equiv
[
input
]
for
input
in
op
.
inputs
])
s_g
.
add_feature
(
sync_to
(
env
,
equiv2
,
transform
))
s_g
.
add_feature
(
sync_to
(
env
,
equiv2
,
transform
))
# Any change to s_g will now be transferred to g
opt
.
optimize
(
s_g
)
if
self
.
make_composite
:
def
follow_inplace
(
r
):
# Tries to find the earliest r2 in g such that r destroys r2
# If no such r2 is found, returns None
op
=
r
.
owner
if
op
is
None
or
r
in
g
.
inputs
or
r
in
g
.
orphans
():
return
None
...
...
@@ -211,6 +275,8 @@ class CliqueOptimizer(opt.Optimizer):
for
i
,
output
in
enumerate
(
g
.
outputs
):
destroyed
=
follow_inplace
(
output
)
if
destroyed
is
not
None
and
destroyed
in
g
.
inputs
:
# we transfer the inplace operation only if it is
# an input that is destroyed
inplace_pattern
[
i
]
=
g
.
inputs
.
index
(
destroyed
)
C
=
scalar
.
composite
(
s_g
.
inputs
,
s_g
.
outputs
)
ec
=
Broadcast
(
C
,
g
.
inputs
,
inplace_pattern
=
inplace_pattern
)
...
...
@@ -218,6 +284,17 @@ class CliqueOptimizer(opt.Optimizer):
def
sync_to
(
target
,
equiv
,
transform
):
"""
Usage: sync_to(target, equiv, transform)
* target: an Env
* equiv: a dictionary that maps results and ops to results and ops
in target
* transform: a function that takes (op, equiv) as inputs and
returns a new op.
Returns a Feature that can be added to an Env and mirrors all
modifications to that env with modifications to the target env.
"""
class
Synchronize
(
gof
.
Listener
,
gof
.
Constraint
):
...
...
@@ -259,44 +336,3 @@ def sync_to(target, equiv, transform):
return
Synchronize
"""
This variable is used in compile.prog as the optimizer for all programs built
using either compile.single, compile.to_func, and compile.prog.
Old code::
if 0:
def optimizer(lst):
begin = gof.SeqOptimizer([])
end = gof.SeqOptimizer([gof.DummyRemover])
seq_opt = gof.SeqOptimizer(begin + lst + end)
return gof.PythonOpt(gof.MergeOptMerge(seq_opt))
if 0:
optimizer_begin = gof.SeqOptimizer([opt for name, opt in [
['double_transpose_eliminator', pattern_opt((transpose, (transpose, 'x')), 'x')],
['addxx_to_twice', pattern_opt((add_elemwise, 'x', 'x'), (twice, 'x'))],
['twice_to_itwice', op_sub(twice, itwice)],
['mulxx_to_sqr', pattern_opt((mul_elemwise, 'x', 'x'), (sqr, 'x'))],
['sqr_to_isqr', op_sub(sqr, isqr)],
['add_to_iadd', op_sub(add_elemwise, iadd_elemwise)],
['add_to_iadd_reverse', pattern_opt((add_elemwise, 'x', 'y'),
(iadd_elemwise, 'y', 'x'))]]])
# ['remove_copies', gof.OpRemover(array_copy)],
# [None, gof.DummyRemover] # has to be at the end
"""
scalar_opt.py
浏览文件 @
bf3df169
...
...
@@ -23,6 +23,39 @@ logpow = Pattern((Log, (Pow, 'x', 'y')),
class
Canonizer
(
gof
.
Optimizer
):
"""
Simplification tool.
Usage: Canonizer(main, inverse, reciprocal, mainfn, invfn, recfn, transform)
* main: a suitable Op class that is commutative, associative and takes
one to an arbitrary number of inputs, e.g. Add or Mul
* inverse: an Op class such that inverse(main(x, y), y) == x
e.g. Sub or Div
* reciprocal: a function such that main(x, reciprocal(y)) == inverse(x, y)
e.g. Neg or Inv
* mainfn, invfn, recfn: functions that behave just like the previous three
Ops, but on true scalars (e.g. their impl)
* transform: a function that maps (numerator, denominatur) where numerator
and denominator are lists of Result instances, to new lists
where further simplifications may have been applied.
Examples:
add_canonizer = Canonizer(Add, Sub, Neg, lambda *inputs: sum(inputs), ...)
mul_canonizer = Canonizer(Mul, Div, Inv, lambda *inputs: product(inputs), ...)
Examples of optimizations mul_canonizer can perform:
x / x -> 1
(x * y) / x -> y
x / y / x -> 1 / y
x / y / z -> x / (y * z)
x / (y / z) -> (x * z) / y
(a / b) * (b / c) * (c / d) -> a / d
(2.0 * x) / (4.0 * y) -> (0.5 * x) / y
2 * x / 2 -> x
"""
def
__init__
(
self
,
main
,
inverse
,
reciprocal
,
mainfn
,
invfn
,
recfn
,
transform
=
None
):
self
.
main
=
main
...
...
@@ -37,10 +70,15 @@ class Canonizer(gof.Optimizer):
def
apply
(
self
,
env
):
def
canonize
(
r
):
if
r
in
env
.
inputs
or
r
in
env
.
orphans
():
return
def
flatten
(
r
,
nclients_check
=
True
):
# Collapses a tree of main/inverse/reciprocal Ops (aka Mul/Div/Inv or Add/Sub/Neg)
# into a list of numerators and a list of denominators
# e.g. (x*(1/y))*(x/(z/a)) aka Mul(Mul(x, (Inv, y)), Div(x, Div(z, a))) -> [x, x, a], [z, y]
op
=
r
.
owner
if
op
is
None
or
r
in
env
.
inputs
or
r
in
env
.
orphans
():
return
[
r
],
[]
...
...
@@ -50,9 +88,11 @@ class Canonizer(gof.Optimizer):
nums
=
[
x
[
0
]
for
x
in
results
]
denums
=
[
x
[
1
]
for
x
in
results
]
elif
isinstance
(
op
,
self
.
inverse
)
and
(
not
nclients_check
or
env
.
nclients
(
r
)
==
1
):
# num, denum of the second argument are added to the denum, num respectively
nums
=
[
results
[
0
][
0
],
results
[
1
][
1
]]
denums
=
[
results
[
0
][
1
],
results
[
1
][
0
]]
elif
isinstance
(
op
,
self
.
reciprocal
)
and
(
not
nclients_check
or
env
.
nclients
(
r
)
==
1
):
# num, denum of the sole argument are added to the denum, num respectively
nums
=
[
results
[
0
][
1
]]
denums
=
[
results
[
0
][
0
]]
else
:
...
...
@@ -69,23 +109,30 @@ class Canonizer(gof.Optimizer):
for
input
in
r
.
owner
.
inputs
:
canonize
(
input
)
return
# Terms that are both in the num and denum lists cancel each other
for
d
in
list
(
denum
):
if
d
in
list
(
num
):
# list.remove only removes the element once
num
.
remove
(
d
)
denum
.
remove
(
d
)
# We identify the constants in num and denum
numct
,
num
=
utils
.
partition
(
lambda
factor
:
getattr
(
factor
,
'constant'
,
False
)
and
factor
.
data
is
not
None
,
num
)
denumct
,
denum
=
utils
.
partition
(
lambda
factor
:
getattr
(
factor
,
'constant'
,
False
)
and
factor
.
data
is
not
None
,
denum
)
# All constants in num and denum are combined into a single constant which we add to num (unless it's a neutral constant)
v
=
self
.
invfn
(
self
.
mainfn
(
*
[
x
.
data
for
x
in
numct
]),
self
.
mainfn
(
*
[
x
.
data
for
x
in
denumct
]))
if
v
!=
self
.
neutral
:
num
.
insert
(
0
,
C
(
v
))
# We optimize the num and denum lists further if requested
if
self
.
transform
is
not
None
:
num
,
denum
=
self
.
transform
(
env
,
num
,
denum
)
def
make
(
factors
):
# Combines the factors using self.main (aka Mul) depending
# on the number of elements.
n
=
len
(
factors
)
if
n
==
0
:
return
None
...
...
@@ -98,10 +145,13 @@ class Canonizer(gof.Optimizer):
if
numr
is
None
:
if
denumr
is
None
:
# Everything cancelled each other so we're left with
# the neutral element.
new_r
=
Scalar
(
dtype
=
r
.
dtype
)
new_r
.
constant
=
True
new_r
.
data
=
self
.
neutral
else
:
# There's no numerator so we use reciprocal
new_r
=
self
.
reciprocal
(
denumr
)
.
out
else
:
if
denumr
is
None
:
...
...
@@ -109,6 +159,7 @@ class Canonizer(gof.Optimizer):
else
:
new_r
=
self
.
inverse
(
numr
,
denumr
)
.
out
# Hopefully this won't complain!
env
.
replace
(
r
,
new_r
)
for
factor
in
num
+
denum
:
...
...
@@ -119,11 +170,28 @@ class Canonizer(gof.Optimizer):
def
group_powers
(
env
,
num
,
denum
):
"""
Plugin for Canonizer: use as Canonizer(..., transform = group_powers)
Takes num, denum such that mul(*num) / mul(*denum) is in env
and searches for instances of exp(x) or x**y in order to group
together powers of the same variable. Returns num2, denum2 in
which the grouping has been done.
Note: this function does not modify env.
Examples:
group_powers([x, exp(x), exp(y)], [exp(z)]) -> [x, exp(x+y-z)], []
"""
# maps a base to the list of powers it is raised to in the
# numerator/denominator lists.
num_powers
=
{}
denum_powers
=
{}
def
populate
(
d
,
seq
):
# For each instance of exp or pow in seq, removes it from seq
# and does d[base].append(power).
for
factor
in
list
(
seq
):
op
=
factor
.
owner
if
op
is
None
or
factor
in
env
.
inputs
or
factor
in
env
.
orphans
():
...
...
@@ -139,6 +207,8 @@ def group_powers(env, num, denum):
populate
(
denum_powers
,
denum
)
for
x
in
set
(
num_powers
.
keys
()
+
denum_powers
.
keys
()):
# we append base ** (num_powers[base] - denum_powers[base])
# to the num list
try
:
num_ys
=
num_powers
.
pop
(
x
)
except
KeyError
:
num_ys
=
[]
...
...
@@ -148,6 +218,7 @@ def group_powers(env, num, denum):
num_r
=
num_ys
and
add
(
*
num_ys
)
or
C
(
0
)
denum_r
=
denum_ys
and
add
(
*
denum_ys
)
or
C
(
0
)
if
x
==
'e'
:
num
.
append
(
exp
(
num_r
-
denum_r
))
else
:
...
...
tensor.py
浏览文件 @
bf3df169
...
...
@@ -80,17 +80,14 @@ def astensor(data, broadcastable=None, name=None):
if
isinstance
(
data
,
BaseTensor
):
if
broadcastable
is
not
None
and
list
(
data
.
broadcastable
)
!=
list
(
broadcastable
):
raise
TypeError
(
"The data to wrap as a Tensor has the wrong broadcastable pattern. Expected
%
s, got
%
s."
%
(
broadcastable
,
data
.
broadcastable
))
if
isinstance
(
data
,
Tensor
)
and
(
name
is
None
or
name
==
data
.
name
):
return
data
else
:
t
=
Tensor
(
data
.
dtype
,
data
.
broadcastable
,
name
=
name
)
t
.
data
=
data
return
t
if
name
is
not
None
and
name
!=
data
.
name
:
raise
ValueError
(
"Cannot rename an existing Tensor."
)
return
data
elif
isinstance
(
data
,
Result
):
data
=
data
.
data
raise
TypeError
(
"Cannot make a Tensor out of a non-Tensor result."
)
if
data
is
None
and
broadcastable
is
None
:
raise
TypeError
(
"Cannot make a Tensor out of None
or a Result with no data
."
)
raise
TypeError
(
"Cannot make a Tensor out of None."
)
data
=
numpy
.
asarray
(
data
)
if
broadcastable
is
None
:
...
...
@@ -107,38 +104,6 @@ s2t.astensor = astensor
# Supporting Ops
############################
def
_scalar_switch
(
normal_f
,
scalar_f
,
scalar_f_reverse
=
None
):
"""a decorator for operators before broadcasting works properly"""
def
f
(
x
,
y
):
def
as_tensor
(
obj
):
if
isinstance
(
obj
,
Tensor
):
return
obj
else
:
return
astensor
(
obj
)
x
,
y
=
as_tensor
(
x
),
as_tensor
(
y
)
if
0
not
in
y
.
broadcastable
:
return
scalar_f
(
x
,
y
)
if
0
not
in
x
.
broadcastable
:
if
scalar_f_reverse
:
return
scalar_f_reverse
(
y
,
x
)
else
:
raise
TypeError
(
"You cannot do this operation on a scalar."
)
return
normal_f
(
x
,
y
)
return
f
def
_assert_same_shapes
(
x
,
*
rest
):
"""Ensure that all inputs to the function impl have the same size (foils numpy's broadcasting)"""
shape
=
x
.
shape
for
other
in
rest
:
if
other
.
shape
!=
shape
:
raise
ValueError
(
_assert_same_shapes
.
E_shape
,
shape
,
other
.
shape
)
_assert_same_shapes
.
E_shape
=
"The dimensions of the inputs do not match."
def
_assert_tensor_scalar
(
x
,
a
):
"""ensure that the second input is a scalar"""
if
numpy
.
product
(
a
.
shape
)
!=
1
:
raise
ValueError
(
"The second argument must be a scalar."
)
# this has a different name, because _as_tensor is the function which ops use
# to upcast their arguments... this internal-use function is a good place to put debugging stuff, better than the global astensor.
_as_tensor
=
astensor
...
...
@@ -450,8 +415,6 @@ class Gemm(_Op):
return
[
'<iostream>'
]
def
c_libraries
(
self
):
return
blas
.
ldflags
()
#def c_var_names(self):
# return [['_z', '_a', '_x', '_y', '_b'], ['_zout']]
def
c_validate_update
(
self
,
*
args
):
return
""
def
c_validate_update_cleanup
(
self
,
*
args
):
...
...
@@ -612,125 +575,3 @@ class Gemm(_Op):
"""
%
dict
(
locals
(),
**
sub
)
gemm
=
gof
.
op
.
constructor
(
Gemm
)
if
0
:
##########################
# Comparisons
##########################
# Less-than
class
lt_elemwise
(
_Elemwise
):
def
__init__
(
self
,
*
args
):
raise
NotImplementedError
()
class
lt_scalar_r
(
_Elemwise
):
def
__init__
(
self
,
*
args
):
raise
NotImplementedError
()
# Less-than or equal
class
le_elemwise
(
_Elemwise
):
def
__init__
(
self
,
*
args
):
raise
NotImplementedError
()
class
le_scalar_r
(
_Elemwise
):
def
__init__
(
self
,
*
args
):
raise
NotImplementedError
()
# Greater-than or equal
class
gt_elemwise
(
_Elemwise
):
def
__init__
(
self
,
*
args
):
raise
NotImplementedError
()
class
gt_scalar_r
(
_Elemwise
):
def
__init__
(
self
,
*
args
):
raise
NotImplementedError
()
# Greater-than or equal
class
ge_elemwise
(
_Elemwise
):
def
__init__
(
self
,
*
args
):
raise
NotImplementedError
()
class
ge_scalar_r
(
_Elemwise
):
def
__init__
(
self
,
*
args
):
raise
NotImplementedError
()
if
0
:
def
_broadcastable_pattern
(
pattern
):
def
factory
(
data
=
None
,
name
=
None
,
dtype
=
None
):
if
data
:
assert
len
(
data
.
shape
)
==
len
(
pattern
)
if
dtype
is
not
None
:
assert
dtype
is
data
.
dtype
dtype
=
data
.
dtype
rval
=
Tensor
(
dtype
,
pattern
,
name
)
rval
.
data
=
data
else
:
rval
=
Tensor
(
dtype
,
pattern
,
name
)
return
rval
return
factory
row
=
_broadcastable_pattern
([
1
,
0
])
col
=
_broadcastable_pattern
([
0
,
1
])
matrix
=
_broadcastable_pattern
([
0
,
0
])
if
0
:
#old __init__ code
"""Create a Tensor
If data is given:
- constant defaults to True
- if dtype is given, it must match data.dtype
- otherwise: default is data.dtype
- if broadcastable is given, len(broadcastable) must match len(data.shape)
- otherwise: if it is constant, it defaults to 1 where shape[i]==1
- if it is not constant, it defaults to 0s
If data is not given:
- constant defaults to False
"""
if
dtype
is
None
or
broadcastable
is
None
:
if
data
is
None
:
raise
TypeError
(
"Provide non-None data to complete the dtype and broadcastable flags."
)
data
=
numpy
.
asarray
(
data
)
if
constant
is
None
:
constant
=
True
dtype
=
data
.
dtype
if
constant
:
broadcastable
=
[
1
*
(
x
==
1
)
for
x
in
data
.
shape
]
else
:
broadcastable
=
[
0
]
*
len
(
data
.
shape
)
if
0
:
def
tensor__new__
(
cls
,
*
args
,
**
kwargs
):
"""__new__ is overloaded to handle the special form Tensor(x) when x is
a Tensor or an Op whose default output is a Tensor. In these cases, the
argument x is returned, and a new Tensor is not created.
"""
if
len
(
args
)
==
1
:
a
=
args
[
0
]
t
=
super
(
Tensor
,
cls
)
.
__new__
(
cls
,
*
args
,
**
kwargs
)
t
.
__init__
(
*
args
,
**
kwargs
)
return
t
# def upcast(dtype, *dtypes):
# z = numpy.zeros((), dtype = dtype)
# for dtype in dtypes:
# z = z + numpy.zeros((), dtype = dtype)
# return str(z.dtype)
# for dtype in i_dtypes:
# if dtype is None:
# raise TypeError("Expected a Tensor.")
# upcasted = upcast(*i_dtypes)
# return [upcasted] * self.nout
# # try:
# # dmap = self.destroy_map()
# # except AttributeError:
# # dmap = {}
# # rval = []
# # for i in xrange(self.nout):
# # if i in dmap:
# # destroyed = dmap[output]
# # if len(destroyed) != 1:
# # raise TypeError("Cannot infer dtype of output %s because it destroys more than one input." % output)
# # rval.append(destroyed[0])
# # else:
# # rval.append(upcasted)
# # return rval
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论