Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
bf3df169
提交
bf3df169
authored
4月 17, 2008
作者:
Olivier Breuleux
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fixes here and there, doc
上级
7158e6d8
全部展开
隐藏空白字符变更
内嵌
并排
正在显示
5 个修改的文件
包含
87 行增加
和
170 行删除
+87
-170
_test_scalar_opt.py
_test_scalar_opt.py
+4
-3
base_tensor.py
base_tensor.py
+5
-1
opt.py
opt.py
+0
-0
scalar_opt.py
scalar_opt.py
+73
-2
tensor.py
tensor.py
+5
-164
没有找到文件。
_test_scalar_opt.py
浏览文件 @
bf3df169
...
@@ -42,7 +42,7 @@ class _test_opts(unittest.TestCase):
...
@@ -42,7 +42,7 @@ class _test_opts(unittest.TestCase):
# x, y, z = inputs()
# x, y, z = inputs()
# a, b, c, d = more_inputs()
# a, b, c, d = more_inputs()
# # e = (2.0 * x) / (2.0 * y)
# # e = (2.0 * x) / (2.0 * y)
#
#
e = (2.0 * x) / (4.0 * y)
# e = (2.0 * x) / (4.0 * y)
# # e = x / (y / z)
# # e = x / (y / z)
# # e = (x * y) / x
# # e = (x * y) / x
# # e = (x / y) * (y / z) * (z / x)
# # e = (x / y) * (y / z) * (z / x)
...
@@ -71,11 +71,12 @@ class _test_opts(unittest.TestCase):
...
@@ -71,11 +71,12 @@ class _test_opts(unittest.TestCase):
# # e = (a - b) + (b - c) + (c - d)
# # e = (a - b) + (b - c) + (c - d)
# # e = x + -y
# # e = x + -y
# # e = a - b - b + a + b + c + b - c
# # e = a - b - b + a + b + c + b - c
# e = x + log(y) - x + y
# # e = x + log(y) - x + y
# e = 2.0 + x + 4.0
# g = Env([x, y, z, a, b, c, d], [e])
# g = Env([x, y, z, a, b, c, d], [e])
# print g
# print g
# gof.ConstantFinder().optimize(g)
# gof.ConstantFinder().optimize(g)
# addfn = lambda *inputs:
reduce(lambda x, y: x + y, (0,) +
inputs)
# addfn = lambda *inputs:
sum(
inputs)
# subfn = lambda x, y: x - y
# subfn = lambda x, y: x - y
# negfn = lambda x: -x
# negfn = lambda x: -x
# Canonizer(Add, Sub, Neg, addfn, subfn, negfn).optimize(g)
# Canonizer(Add, Sub, Neg, addfn, subfn, negfn).optimize(g)
...
...
base_tensor.py
浏览文件 @
bf3df169
...
@@ -58,7 +58,7 @@ class BaseTensor(Result):
...
@@ -58,7 +58,7 @@ class BaseTensor(Result):
# filter
# filter
#
#
def
filter
(
self
,
arr
):
def
filter
(
self
,
arr
):
"""
cast to an L{numpy.ndarray} and ensure arr has correct rank, shape
"""
"""
Cast to an L{numpy.ndarray} and ensure arr has correct rank and shape.
"""
if
not
(
isinstance
(
arr
,
numpy
.
ndarray
)
\
if
not
(
isinstance
(
arr
,
numpy
.
ndarray
)
\
and
arr
.
dtype
==
self
.
dtype
):
and
arr
.
dtype
==
self
.
dtype
):
arr
=
numpy
.
asarray
(
arr
,
dtype
=
self
.
dtype
)
arr
=
numpy
.
asarray
(
arr
,
dtype
=
self
.
dtype
)
...
@@ -102,6 +102,9 @@ class BaseTensor(Result):
...
@@ -102,6 +102,9 @@ class BaseTensor(Result):
# Description for constant folding
# Description for constant folding
#
#
def
desc
(
self
):
def
desc
(
self
):
"""
Returns a hashable description of this BaseTensor.
"""
if
self
.
data
is
not
None
:
if
self
.
data
is
not
None
:
return
(
BaseTensor
,
self
.
dtype
,
self
.
broadcastable
,
self
.
data
.
data
[:])
return
(
BaseTensor
,
self
.
dtype
,
self
.
broadcastable
,
self
.
data
.
data
[:])
else
:
else
:
...
@@ -210,6 +213,7 @@ class BaseTensor(Result):
...
@@ -210,6 +213,7 @@ class BaseTensor(Result):
};
};
"""
"""
return
template
%
dict
(
nbits
=
64
,
half_nbits
=
32
)
+
template
%
dict
(
nbits
=
128
,
half_nbits
=
64
)
return
template
%
dict
(
nbits
=
64
,
half_nbits
=
32
)
+
template
%
dict
(
nbits
=
128
,
half_nbits
=
64
)
# todo: use C templating
############################
############################
...
...
opt.py
浏览文件 @
bf3df169
差异被折叠。
点击展开。
scalar_opt.py
浏览文件 @
bf3df169
...
@@ -23,6 +23,39 @@ logpow = Pattern((Log, (Pow, 'x', 'y')),
...
@@ -23,6 +23,39 @@ logpow = Pattern((Log, (Pow, 'x', 'y')),
class
Canonizer
(
gof
.
Optimizer
):
class
Canonizer
(
gof
.
Optimizer
):
"""
Simplification tool.
Usage: Canonizer(main, inverse, reciprocal, mainfn, invfn, recfn, transform)
* main: a suitable Op class that is commutative, associative and takes
one to an arbitrary number of inputs, e.g. Add or Mul
* inverse: an Op class such that inverse(main(x, y), y) == x
e.g. Sub or Div
* reciprocal: a function such that main(x, reciprocal(y)) == inverse(x, y)
e.g. Neg or Inv
* mainfn, invfn, recfn: functions that behave just like the previous three
Ops, but on true scalars (e.g. their impl)
* transform: a function that maps (numerator, denominatur) where numerator
and denominator are lists of Result instances, to new lists
where further simplifications may have been applied.
Examples:
add_canonizer = Canonizer(Add, Sub, Neg, lambda *inputs: sum(inputs), ...)
mul_canonizer = Canonizer(Mul, Div, Inv, lambda *inputs: product(inputs), ...)
Examples of optimizations mul_canonizer can perform:
x / x -> 1
(x * y) / x -> y
x / y / x -> 1 / y
x / y / z -> x / (y * z)
x / (y / z) -> (x * z) / y
(a / b) * (b / c) * (c / d) -> a / d
(2.0 * x) / (4.0 * y) -> (0.5 * x) / y
2 * x / 2 -> x
"""
def
__init__
(
self
,
main
,
inverse
,
reciprocal
,
mainfn
,
invfn
,
recfn
,
transform
=
None
):
def
__init__
(
self
,
main
,
inverse
,
reciprocal
,
mainfn
,
invfn
,
recfn
,
transform
=
None
):
self
.
main
=
main
self
.
main
=
main
...
@@ -37,10 +70,15 @@ class Canonizer(gof.Optimizer):
...
@@ -37,10 +70,15 @@ class Canonizer(gof.Optimizer):
def
apply
(
self
,
env
):
def
apply
(
self
,
env
):
def
canonize
(
r
):
def
canonize
(
r
):
if
r
in
env
.
inputs
or
r
in
env
.
orphans
():
if
r
in
env
.
inputs
or
r
in
env
.
orphans
():
return
return
def
flatten
(
r
,
nclients_check
=
True
):
def
flatten
(
r
,
nclients_check
=
True
):
# Collapses a tree of main/inverse/reciprocal Ops (aka Mul/Div/Inv or Add/Sub/Neg)
# into a list of numerators and a list of denominators
# e.g. (x*(1/y))*(x/(z/a)) aka Mul(Mul(x, (Inv, y)), Div(x, Div(z, a))) -> [x, x, a], [z, y]
op
=
r
.
owner
op
=
r
.
owner
if
op
is
None
or
r
in
env
.
inputs
or
r
in
env
.
orphans
():
if
op
is
None
or
r
in
env
.
inputs
or
r
in
env
.
orphans
():
return
[
r
],
[]
return
[
r
],
[]
...
@@ -50,9 +88,11 @@ class Canonizer(gof.Optimizer):
...
@@ -50,9 +88,11 @@ class Canonizer(gof.Optimizer):
nums
=
[
x
[
0
]
for
x
in
results
]
nums
=
[
x
[
0
]
for
x
in
results
]
denums
=
[
x
[
1
]
for
x
in
results
]
denums
=
[
x
[
1
]
for
x
in
results
]
elif
isinstance
(
op
,
self
.
inverse
)
and
(
not
nclients_check
or
env
.
nclients
(
r
)
==
1
):
elif
isinstance
(
op
,
self
.
inverse
)
and
(
not
nclients_check
or
env
.
nclients
(
r
)
==
1
):
# num, denum of the second argument are added to the denum, num respectively
nums
=
[
results
[
0
][
0
],
results
[
1
][
1
]]
nums
=
[
results
[
0
][
0
],
results
[
1
][
1
]]
denums
=
[
results
[
0
][
1
],
results
[
1
][
0
]]
denums
=
[
results
[
0
][
1
],
results
[
1
][
0
]]
elif
isinstance
(
op
,
self
.
reciprocal
)
and
(
not
nclients_check
or
env
.
nclients
(
r
)
==
1
):
elif
isinstance
(
op
,
self
.
reciprocal
)
and
(
not
nclients_check
or
env
.
nclients
(
r
)
==
1
):
# num, denum of the sole argument are added to the denum, num respectively
nums
=
[
results
[
0
][
1
]]
nums
=
[
results
[
0
][
1
]]
denums
=
[
results
[
0
][
0
]]
denums
=
[
results
[
0
][
0
]]
else
:
else
:
...
@@ -69,23 +109,30 @@ class Canonizer(gof.Optimizer):
...
@@ -69,23 +109,30 @@ class Canonizer(gof.Optimizer):
for
input
in
r
.
owner
.
inputs
:
for
input
in
r
.
owner
.
inputs
:
canonize
(
input
)
canonize
(
input
)
return
return
# Terms that are both in the num and denum lists cancel each other
for
d
in
list
(
denum
):
for
d
in
list
(
denum
):
if
d
in
list
(
num
):
if
d
in
list
(
num
):
# list.remove only removes the element once
num
.
remove
(
d
)
num
.
remove
(
d
)
denum
.
remove
(
d
)
denum
.
remove
(
d
)
# We identify the constants in num and denum
numct
,
num
=
utils
.
partition
(
lambda
factor
:
getattr
(
factor
,
'constant'
,
False
)
and
factor
.
data
is
not
None
,
num
)
numct
,
num
=
utils
.
partition
(
lambda
factor
:
getattr
(
factor
,
'constant'
,
False
)
and
factor
.
data
is
not
None
,
num
)
denumct
,
denum
=
utils
.
partition
(
lambda
factor
:
getattr
(
factor
,
'constant'
,
False
)
and
factor
.
data
is
not
None
,
denum
)
denumct
,
denum
=
utils
.
partition
(
lambda
factor
:
getattr
(
factor
,
'constant'
,
False
)
and
factor
.
data
is
not
None
,
denum
)
# All constants in num and denum are combined into a single constant which we add to num (unless it's a neutral constant)
v
=
self
.
invfn
(
self
.
mainfn
(
*
[
x
.
data
for
x
in
numct
]),
self
.
mainfn
(
*
[
x
.
data
for
x
in
denumct
]))
v
=
self
.
invfn
(
self
.
mainfn
(
*
[
x
.
data
for
x
in
numct
]),
self
.
mainfn
(
*
[
x
.
data
for
x
in
denumct
]))
if
v
!=
self
.
neutral
:
if
v
!=
self
.
neutral
:
num
.
insert
(
0
,
C
(
v
))
num
.
insert
(
0
,
C
(
v
))
# We optimize the num and denum lists further if requested
if
self
.
transform
is
not
None
:
if
self
.
transform
is
not
None
:
num
,
denum
=
self
.
transform
(
env
,
num
,
denum
)
num
,
denum
=
self
.
transform
(
env
,
num
,
denum
)
def
make
(
factors
):
def
make
(
factors
):
# Combines the factors using self.main (aka Mul) depending
# on the number of elements.
n
=
len
(
factors
)
n
=
len
(
factors
)
if
n
==
0
:
if
n
==
0
:
return
None
return
None
...
@@ -98,10 +145,13 @@ class Canonizer(gof.Optimizer):
...
@@ -98,10 +145,13 @@ class Canonizer(gof.Optimizer):
if
numr
is
None
:
if
numr
is
None
:
if
denumr
is
None
:
if
denumr
is
None
:
# Everything cancelled each other so we're left with
# the neutral element.
new_r
=
Scalar
(
dtype
=
r
.
dtype
)
new_r
=
Scalar
(
dtype
=
r
.
dtype
)
new_r
.
constant
=
True
new_r
.
constant
=
True
new_r
.
data
=
self
.
neutral
new_r
.
data
=
self
.
neutral
else
:
else
:
# There's no numerator so we use reciprocal
new_r
=
self
.
reciprocal
(
denumr
)
.
out
new_r
=
self
.
reciprocal
(
denumr
)
.
out
else
:
else
:
if
denumr
is
None
:
if
denumr
is
None
:
...
@@ -109,6 +159,7 @@ class Canonizer(gof.Optimizer):
...
@@ -109,6 +159,7 @@ class Canonizer(gof.Optimizer):
else
:
else
:
new_r
=
self
.
inverse
(
numr
,
denumr
)
.
out
new_r
=
self
.
inverse
(
numr
,
denumr
)
.
out
# Hopefully this won't complain!
env
.
replace
(
r
,
new_r
)
env
.
replace
(
r
,
new_r
)
for
factor
in
num
+
denum
:
for
factor
in
num
+
denum
:
...
@@ -119,11 +170,28 @@ class Canonizer(gof.Optimizer):
...
@@ -119,11 +170,28 @@ class Canonizer(gof.Optimizer):
def
group_powers
(
env
,
num
,
denum
):
def
group_powers
(
env
,
num
,
denum
):
"""
Plugin for Canonizer: use as Canonizer(..., transform = group_powers)
Takes num, denum such that mul(*num) / mul(*denum) is in env
and searches for instances of exp(x) or x**y in order to group
together powers of the same variable. Returns num2, denum2 in
which the grouping has been done.
Note: this function does not modify env.
Examples:
group_powers([x, exp(x), exp(y)], [exp(z)]) -> [x, exp(x+y-z)], []
"""
# maps a base to the list of powers it is raised to in the
# numerator/denominator lists.
num_powers
=
{}
num_powers
=
{}
denum_powers
=
{}
denum_powers
=
{}
def
populate
(
d
,
seq
):
def
populate
(
d
,
seq
):
# For each instance of exp or pow in seq, removes it from seq
# and does d[base].append(power).
for
factor
in
list
(
seq
):
for
factor
in
list
(
seq
):
op
=
factor
.
owner
op
=
factor
.
owner
if
op
is
None
or
factor
in
env
.
inputs
or
factor
in
env
.
orphans
():
if
op
is
None
or
factor
in
env
.
inputs
or
factor
in
env
.
orphans
():
...
@@ -139,6 +207,8 @@ def group_powers(env, num, denum):
...
@@ -139,6 +207,8 @@ def group_powers(env, num, denum):
populate
(
denum_powers
,
denum
)
populate
(
denum_powers
,
denum
)
for
x
in
set
(
num_powers
.
keys
()
+
denum_powers
.
keys
()):
for
x
in
set
(
num_powers
.
keys
()
+
denum_powers
.
keys
()):
# we append base ** (num_powers[base] - denum_powers[base])
# to the num list
try
:
num_ys
=
num_powers
.
pop
(
x
)
try
:
num_ys
=
num_powers
.
pop
(
x
)
except
KeyError
:
num_ys
=
[]
except
KeyError
:
num_ys
=
[]
...
@@ -148,6 +218,7 @@ def group_powers(env, num, denum):
...
@@ -148,6 +218,7 @@ def group_powers(env, num, denum):
num_r
=
num_ys
and
add
(
*
num_ys
)
or
C
(
0
)
num_r
=
num_ys
and
add
(
*
num_ys
)
or
C
(
0
)
denum_r
=
denum_ys
and
add
(
*
denum_ys
)
or
C
(
0
)
denum_r
=
denum_ys
and
add
(
*
denum_ys
)
or
C
(
0
)
if
x
==
'e'
:
if
x
==
'e'
:
num
.
append
(
exp
(
num_r
-
denum_r
))
num
.
append
(
exp
(
num_r
-
denum_r
))
else
:
else
:
...
...
tensor.py
浏览文件 @
bf3df169
...
@@ -80,17 +80,14 @@ def astensor(data, broadcastable=None, name=None):
...
@@ -80,17 +80,14 @@ def astensor(data, broadcastable=None, name=None):
if
isinstance
(
data
,
BaseTensor
):
if
isinstance
(
data
,
BaseTensor
):
if
broadcastable
is
not
None
and
list
(
data
.
broadcastable
)
!=
list
(
broadcastable
):
if
broadcastable
is
not
None
and
list
(
data
.
broadcastable
)
!=
list
(
broadcastable
):
raise
TypeError
(
"The data to wrap as a Tensor has the wrong broadcastable pattern. Expected
%
s, got
%
s."
%
(
broadcastable
,
data
.
broadcastable
))
raise
TypeError
(
"The data to wrap as a Tensor has the wrong broadcastable pattern. Expected
%
s, got
%
s."
%
(
broadcastable
,
data
.
broadcastable
))
if
isinstance
(
data
,
Tensor
)
and
(
name
is
None
or
name
==
data
.
name
):
if
name
is
not
None
and
name
!=
data
.
name
:
return
data
raise
ValueError
(
"Cannot rename an existing Tensor."
)
else
:
return
data
t
=
Tensor
(
data
.
dtype
,
data
.
broadcastable
,
name
=
name
)
t
.
data
=
data
return
t
elif
isinstance
(
data
,
Result
):
elif
isinstance
(
data
,
Result
):
data
=
data
.
data
raise
TypeError
(
"Cannot make a Tensor out of a non-Tensor result."
)
if
data
is
None
and
broadcastable
is
None
:
if
data
is
None
and
broadcastable
is
None
:
raise
TypeError
(
"Cannot make a Tensor out of None
or a Result with no data
."
)
raise
TypeError
(
"Cannot make a Tensor out of None."
)
data
=
numpy
.
asarray
(
data
)
data
=
numpy
.
asarray
(
data
)
if
broadcastable
is
None
:
if
broadcastable
is
None
:
...
@@ -107,38 +104,6 @@ s2t.astensor = astensor
...
@@ -107,38 +104,6 @@ s2t.astensor = astensor
# Supporting Ops
# Supporting Ops
############################
############################
def
_scalar_switch
(
normal_f
,
scalar_f
,
scalar_f_reverse
=
None
):
"""a decorator for operators before broadcasting works properly"""
def
f
(
x
,
y
):
def
as_tensor
(
obj
):
if
isinstance
(
obj
,
Tensor
):
return
obj
else
:
return
astensor
(
obj
)
x
,
y
=
as_tensor
(
x
),
as_tensor
(
y
)
if
0
not
in
y
.
broadcastable
:
return
scalar_f
(
x
,
y
)
if
0
not
in
x
.
broadcastable
:
if
scalar_f_reverse
:
return
scalar_f_reverse
(
y
,
x
)
else
:
raise
TypeError
(
"You cannot do this operation on a scalar."
)
return
normal_f
(
x
,
y
)
return
f
def
_assert_same_shapes
(
x
,
*
rest
):
"""Ensure that all inputs to the function impl have the same size (foils numpy's broadcasting)"""
shape
=
x
.
shape
for
other
in
rest
:
if
other
.
shape
!=
shape
:
raise
ValueError
(
_assert_same_shapes
.
E_shape
,
shape
,
other
.
shape
)
_assert_same_shapes
.
E_shape
=
"The dimensions of the inputs do not match."
def
_assert_tensor_scalar
(
x
,
a
):
"""ensure that the second input is a scalar"""
if
numpy
.
product
(
a
.
shape
)
!=
1
:
raise
ValueError
(
"The second argument must be a scalar."
)
# this has a different name, because _as_tensor is the function which ops use
# this has a different name, because _as_tensor is the function which ops use
# to upcast their arguments... this internal-use function is a good place to put debugging stuff, better than the global astensor.
# to upcast their arguments... this internal-use function is a good place to put debugging stuff, better than the global astensor.
_as_tensor
=
astensor
_as_tensor
=
astensor
...
@@ -450,8 +415,6 @@ class Gemm(_Op):
...
@@ -450,8 +415,6 @@ class Gemm(_Op):
return
[
'<iostream>'
]
return
[
'<iostream>'
]
def
c_libraries
(
self
):
def
c_libraries
(
self
):
return
blas
.
ldflags
()
return
blas
.
ldflags
()
#def c_var_names(self):
# return [['_z', '_a', '_x', '_y', '_b'], ['_zout']]
def
c_validate_update
(
self
,
*
args
):
def
c_validate_update
(
self
,
*
args
):
return
""
return
""
def
c_validate_update_cleanup
(
self
,
*
args
):
def
c_validate_update_cleanup
(
self
,
*
args
):
...
@@ -612,125 +575,3 @@ class Gemm(_Op):
...
@@ -612,125 +575,3 @@ class Gemm(_Op):
"""
%
dict
(
locals
(),
**
sub
)
"""
%
dict
(
locals
(),
**
sub
)
gemm
=
gof
.
op
.
constructor
(
Gemm
)
gemm
=
gof
.
op
.
constructor
(
Gemm
)
if
0
:
##########################
# Comparisons
##########################
# Less-than
class
lt_elemwise
(
_Elemwise
):
def
__init__
(
self
,
*
args
):
raise
NotImplementedError
()
class
lt_scalar_r
(
_Elemwise
):
def
__init__
(
self
,
*
args
):
raise
NotImplementedError
()
# Less-than or equal
class
le_elemwise
(
_Elemwise
):
def
__init__
(
self
,
*
args
):
raise
NotImplementedError
()
class
le_scalar_r
(
_Elemwise
):
def
__init__
(
self
,
*
args
):
raise
NotImplementedError
()
# Greater-than or equal
class
gt_elemwise
(
_Elemwise
):
def
__init__
(
self
,
*
args
):
raise
NotImplementedError
()
class
gt_scalar_r
(
_Elemwise
):
def
__init__
(
self
,
*
args
):
raise
NotImplementedError
()
# Greater-than or equal
class
ge_elemwise
(
_Elemwise
):
def
__init__
(
self
,
*
args
):
raise
NotImplementedError
()
class
ge_scalar_r
(
_Elemwise
):
def
__init__
(
self
,
*
args
):
raise
NotImplementedError
()
if
0
:
def
_broadcastable_pattern
(
pattern
):
def
factory
(
data
=
None
,
name
=
None
,
dtype
=
None
):
if
data
:
assert
len
(
data
.
shape
)
==
len
(
pattern
)
if
dtype
is
not
None
:
assert
dtype
is
data
.
dtype
dtype
=
data
.
dtype
rval
=
Tensor
(
dtype
,
pattern
,
name
)
rval
.
data
=
data
else
:
rval
=
Tensor
(
dtype
,
pattern
,
name
)
return
rval
return
factory
row
=
_broadcastable_pattern
([
1
,
0
])
col
=
_broadcastable_pattern
([
0
,
1
])
matrix
=
_broadcastable_pattern
([
0
,
0
])
if
0
:
#old __init__ code
"""Create a Tensor
If data is given:
- constant defaults to True
- if dtype is given, it must match data.dtype
- otherwise: default is data.dtype
- if broadcastable is given, len(broadcastable) must match len(data.shape)
- otherwise: if it is constant, it defaults to 1 where shape[i]==1
- if it is not constant, it defaults to 0s
If data is not given:
- constant defaults to False
"""
if
dtype
is
None
or
broadcastable
is
None
:
if
data
is
None
:
raise
TypeError
(
"Provide non-None data to complete the dtype and broadcastable flags."
)
data
=
numpy
.
asarray
(
data
)
if
constant
is
None
:
constant
=
True
dtype
=
data
.
dtype
if
constant
:
broadcastable
=
[
1
*
(
x
==
1
)
for
x
in
data
.
shape
]
else
:
broadcastable
=
[
0
]
*
len
(
data
.
shape
)
if
0
:
def
tensor__new__
(
cls
,
*
args
,
**
kwargs
):
"""__new__ is overloaded to handle the special form Tensor(x) when x is
a Tensor or an Op whose default output is a Tensor. In these cases, the
argument x is returned, and a new Tensor is not created.
"""
if
len
(
args
)
==
1
:
a
=
args
[
0
]
t
=
super
(
Tensor
,
cls
)
.
__new__
(
cls
,
*
args
,
**
kwargs
)
t
.
__init__
(
*
args
,
**
kwargs
)
return
t
# def upcast(dtype, *dtypes):
# z = numpy.zeros((), dtype = dtype)
# for dtype in dtypes:
# z = z + numpy.zeros((), dtype = dtype)
# return str(z.dtype)
# for dtype in i_dtypes:
# if dtype is None:
# raise TypeError("Expected a Tensor.")
# upcasted = upcast(*i_dtypes)
# return [upcasted] * self.nout
# # try:
# # dmap = self.destroy_map()
# # except AttributeError:
# # dmap = {}
# # rval = []
# # for i in xrange(self.nout):
# # if i in dmap:
# # destroyed = dmap[output]
# # if len(destroyed) != 1:
# # raise TypeError("Cannot infer dtype of output %s because it destroys more than one input." % output)
# # rval.append(destroyed[0])
# # else:
# # rval.append(upcasted)
# # return rval
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论