Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
59aa5470
提交
59aa5470
authored
4月 16, 2008
作者:
bergstrj@iro.umontreal.ca
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
added care wrt non-int,non-float scalar types
上级
62c90bd2
显示空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
48 行增加
和
35 行删除
+48
-35
scalar.py
scalar.py
+48
-35
没有找到文件。
scalar.py
浏览文件 @
59aa5470
...
@@ -58,14 +58,15 @@ class Scalar(Result):
...
@@ -58,14 +58,15 @@ class Scalar(Result):
def
dtype_specs
(
self
):
def
dtype_specs
(
self
):
try
:
try
:
return
{
'float32'
:
(
float
,
'npy_float32'
,
'PyFloat_Check'
,
'PyFloat_AsDouble'
,
'PyFloat_FromDouble'
),
return
{
'float32'
:
(
numpy
.
float32
,
'npy_float32'
,
'PyFloat_Check'
,
'PyFloat_AsDouble'
,
'PyFloat_FromDouble'
),
'float64'
:
(
float
,
'npy_float64'
,
'PyFloat_Check'
,
'PyFloat_AsDouble'
,
'PyFloat_FromDouble'
),
'float64'
:
(
numpy
.
float64
,
'npy_float64'
,
'PyFloat_Check'
,
'PyFloat_AsDouble'
,
'PyFloat_FromDouble'
),
'int8'
:
(
int
,
'npy_int8'
,
'PyInt_Check'
,
'PyInt_AsLong'
,
'PyInt_FromLong'
),
'complex128'
:
(
numpy
.
complex128
,
'theano_complex128'
,
'PyComplex_Check'
,
'PyComplex_AsCComplex'
,
'PyComplex_FromCComplex'
),
'int16'
:
(
int
,
'npy_int16'
,
'PyInt_Check'
,
'PyInt_AsLong'
,
'PyInt_FromLong'
),
'complex64'
:
(
numpy
.
complex64
,
'theano_complex64'
,
None
,
None
,
None
),
'int32'
:
(
int
,
'npy_int32'
,
'PyInt_Check'
,
'PyInt_AsLong'
,
'PyInt_FromLong'
),
'int8'
:
(
numpy
.
int8
,
'npy_int8'
,
'PyInt_Check'
,
'PyInt_AsLong'
,
'PyInt_FromLong'
),
'int64'
:
(
int
,
'npy_int64'
,
'PyInt_Check'
,
'PyInt_AsLong'
,
'PyInt_FromLong'
),
'int16'
:
(
numpy
.
int16
,
'npy_int16'
,
'PyInt_Check'
,
'PyInt_AsLong'
,
'PyInt_FromLong'
),
'complex128'
:
(
complex
,
'theano_complex128'
,
'PyComplex_Check'
,
'PyComplex_AsCComplex'
,
'PyComplex_FromCComplex'
),
'int32'
:
(
numpy
.
int32
,
'npy_int32'
,
'PyInt_Check'
,
'PyInt_AsLong'
,
'PyInt_FromLong'
),
'complex64'
:
(
complex
,
'theano_complex64'
,
None
,
None
,
None
)}[
self
.
dtype
]
'int64'
:
(
numpy
.
int64
,
'npy_int64'
,
'PyInt_Check'
,
'PyInt_AsLong'
,
'PyInt_FromLong'
)
}[
self
.
dtype
]
except
KeyError
:
except
KeyError
:
raise
TypeError
(
"Unsupported dtype for
%
s:
%
s"
%
(
self
.
__class__
.
__name__
,
self
.
dtype
))
raise
TypeError
(
"Unsupported dtype for
%
s:
%
s"
%
(
self
.
__class__
.
__name__
,
self
.
dtype
))
...
@@ -148,9 +149,7 @@ class Scalar(Result):
...
@@ -148,9 +149,7 @@ class Scalar(Result):
return
template
%
dict
(
nbits
=
64
,
half_nbits
=
32
)
+
template
%
dict
(
nbits
=
128
,
half_nbits
=
64
)
return
template
%
dict
(
nbits
=
64
,
half_nbits
=
32
)
+
template
%
dict
(
nbits
=
128
,
half_nbits
=
64
)
def
__copy__
(
self
):
def
__copy__
(
self
):
"""
"""Return a copy of this instance (with its own attributes)"""
Return a copy of this instance (with its own attributes)
"""
cpy
=
self
.
__class__
(
self
.
dtype
,
self
.
name
)
cpy
=
self
.
__class__
(
self
.
dtype
,
self
.
name
)
cpy
.
data
=
self
.
data
cpy
.
data
=
self
.
data
return
cpy
return
cpy
...
@@ -207,11 +206,16 @@ class ScalarOp(GuardedOp):
...
@@ -207,11 +206,16 @@ class ScalarOp(GuardedOp):
inputs
=
[
as_scalar
(
input
)
for
input
in
inputs
]
inputs
=
[
as_scalar
(
input
)
for
input
in
inputs
]
i_dtypes
=
[
getattr
(
input
,
'dtype'
,
None
)
for
input
in
inputs
]
i_dtypes
=
[
getattr
(
input
,
'dtype'
,
None
)
for
input
in
inputs
]
o_dtypes
=
[
upcast
(
*
i_dtypes
)]
*
self
.
nout
o_dtypes
=
self
.
output_dtypes
(
*
i_dtypes
)
self
.
inputs
=
inputs
self
.
inputs
=
inputs
self
.
outputs
=
[
Scalar
(
dtype
)
for
dtype
in
o_dtypes
]
self
.
outputs
=
[
Scalar
(
dtype
)
for
dtype
in
o_dtypes
]
def
output_dtypes
(
self
,
*
dtypes
):
if
self
.
nout
!=
1
:
raise
NotImplementedError
()
return
upcast
(
*
dtypes
),
def
impl
(
self
,
*
inputs
):
def
impl
(
self
,
*
inputs
):
raise
AbstractFunctionError
()
raise
AbstractFunctionError
()
...
@@ -232,6 +236,13 @@ class UnaryScalarOp(ScalarOp):
...
@@ -232,6 +236,13 @@ class UnaryScalarOp(ScalarOp):
class
BinaryScalarOp
(
ScalarOp
):
class
BinaryScalarOp
(
ScalarOp
):
nin
=
2
nin
=
2
class
FloatUnaryScalarOp
(
UnaryScalarOp
):
def
output_dtypes
(
self
,
input_dtype
):
if
'int'
in
input_dtype
:
return
'float64'
,
if
'float'
in
input_dtype
:
return
input_dtype
,
raise
NotImplementedError
()
class
Add
(
ScalarOp
):
class
Add
(
ScalarOp
):
identity
=
0
identity
=
0
...
@@ -318,6 +329,7 @@ class Neg(UnaryScalarOp):
...
@@ -318,6 +329,7 @@ class Neg(UnaryScalarOp):
return
"
%(z)
s = -
%(x)
s;"
%
locals
()
return
"
%(z)
s = -
%(x)
s;"
%
locals
()
class
Abs
(
UnaryScalarOp
):
class
Abs
(
UnaryScalarOp
):
#TODO: for complex input, output is some flavour of float
def
impl
(
self
,
x
):
def
impl
(
self
,
x
):
return
numpy
.
abs
(
x
)
return
numpy
.
abs
(
x
)
def
grad
(
self
,
(
x
,
),
(
gz
,
)):
def
grad
(
self
,
(
x
,
),
(
gz
,
)):
...
@@ -333,22 +345,24 @@ class Abs(UnaryScalarOp):
...
@@ -333,22 +345,24 @@ class Abs(UnaryScalarOp):
class
Sgn
(
UnaryScalarOp
):
class
Sgn
(
UnaryScalarOp
):
def
impl
(
self
,
x
):
def
impl
(
self
,
x
):
return
numpy
.
abs
(
x
)
/
x
#casting to output type is handled by filter
return
1.0
if
x
>=
0
else
-
1.0
def
grad
(
self
,
(
x
,
),
(
gz
,
)):
def
grad
(
self
,
(
x
,
),
(
gz
,
)):
return
None
,
return
None
,
def
c_code
(
self
,
(
x
,
),
(
z
,
),
sub
):
def
c_code
(
self
,
(
x
,
),
(
z
,
),
sub
):
return
"
%(z)
s =
%(x)
s/
%(prefix)
sabs(
%(x)
s);"
\
#casting is done by compiler
%
dict
(
locals
(),
prefix
=
'float'
in
self
.
inputs
[
0
]
.
dtype
and
'f'
or
''
)
# TODO: C use copysign
#TODO: use copysign
return
"
%(z)
s = (
%(x)
s >= 0) ? 1.0 : -1.0;"
%
locals
()
class
Inv
(
UnaryScalarOp
):
class
Inv
(
Float
UnaryScalarOp
):
def
impl
(
self
,
x
):
def
impl
(
self
,
x
):
return
1
/
x
return
1
.0
/
x
def
grad
(
self
,
(
x
,
),
(
gz
,
)):
def
grad
(
self
,
(
x
,
),
(
gz
,
)):
return
-
gz
/
(
x
*
x
),
return
-
gz
/
(
x
*
x
),
def
c_code
(
self
,
(
x
,
),
(
z
,
),
sub
):
def
c_code
(
self
,
(
x
,
),
(
z
,
),
sub
):
return
"
%(z)
s = 1 /
%(x)
s;"
%
locals
()
return
"
%(z)
s = 1
.0
/
%(x)
s;"
%
locals
()
class
Log
(
UnaryScalarOp
):
class
Log
(
Float
UnaryScalarOp
):
def
impl
(
self
,
x
):
def
impl
(
self
,
x
):
return
math
.
log
(
x
)
return
math
.
log
(
x
)
def
grad
(
self
,
(
x
,
),
(
gz
,
)):
def
grad
(
self
,
(
x
,
),
(
gz
,
)):
...
@@ -356,7 +370,7 @@ class Log(UnaryScalarOp):
...
@@ -356,7 +370,7 @@ class Log(UnaryScalarOp):
def
c_code
(
self
,
(
x
,
),
(
z
,
),
sub
):
def
c_code
(
self
,
(
x
,
),
(
z
,
),
sub
):
return
"
%(z)
s = log(
%(x)
s);"
%
locals
()
return
"
%(z)
s = log(
%(x)
s);"
%
locals
()
class
Log2
(
UnaryScalarOp
):
class
Log2
(
Float
UnaryScalarOp
):
def
impl
(
self
,
x
):
def
impl
(
self
,
x
):
return
numpy
.
log2
(
x
)
return
numpy
.
log2
(
x
)
def
grad
(
self
,
(
x
,
),
(
gz
,
)):
def
grad
(
self
,
(
x
,
),
(
gz
,
)):
...
@@ -364,7 +378,7 @@ class Log2(UnaryScalarOp):
...
@@ -364,7 +378,7 @@ class Log2(UnaryScalarOp):
def
c_code
(
self
,
(
x
,
),
(
z
,
),
sub
):
def
c_code
(
self
,
(
x
,
),
(
z
,
),
sub
):
return
"
%(z)
s = log2(
%(x)
s);"
%
locals
()
return
"
%(z)
s = log2(
%(x)
s);"
%
locals
()
class
Exp
(
UnaryScalarOp
):
class
Exp
(
Float
UnaryScalarOp
):
def
impl
(
self
,
x
):
def
impl
(
self
,
x
):
return
math
.
exp
(
x
)
return
math
.
exp
(
x
)
def
grad
(
self
,
(
x
,
),
(
gz
,
)):
def
grad
(
self
,
(
x
,
),
(
gz
,
)):
...
@@ -380,7 +394,7 @@ class Sqr(UnaryScalarOp):
...
@@ -380,7 +394,7 @@ class Sqr(UnaryScalarOp):
def
c_code
(
self
,
(
x
,
),
(
z
,
),
sub
):
def
c_code
(
self
,
(
x
,
),
(
z
,
),
sub
):
return
"
%(z)
s =
%(x)
s *
%(x)
s;"
%
locals
()
return
"
%(z)
s =
%(x)
s *
%(x)
s;"
%
locals
()
class
Sqrt
(
UnaryScalarOp
):
class
Sqrt
(
Float
UnaryScalarOp
):
def
impl
(
self
,
x
):
def
impl
(
self
,
x
):
return
math
.
sqrt
(
x
)
return
math
.
sqrt
(
x
)
def
grad
(
self
,
(
x
,
),
(
gz
,
)):
def
grad
(
self
,
(
x
,
),
(
gz
,
)):
...
@@ -388,7 +402,7 @@ class Sqrt(UnaryScalarOp):
...
@@ -388,7 +402,7 @@ class Sqrt(UnaryScalarOp):
def
c_code
(
self
,
(
x
,
),
(
z
,
),
sub
):
def
c_code
(
self
,
(
x
,
),
(
z
,
),
sub
):
return
"
%(z)
s = sqrt(
%(x)
s);"
%
locals
()
return
"
%(z)
s = sqrt(
%(x)
s);"
%
locals
()
class
Cos
(
UnaryScalarOp
):
class
Cos
(
Float
UnaryScalarOp
):
def
impl
(
self
,
x
):
def
impl
(
self
,
x
):
return
math
.
cos
(
x
)
return
math
.
cos
(
x
)
def
grad
(
self
,
(
x
,
),
(
gz
,
)):
def
grad
(
self
,
(
x
,
),
(
gz
,
)):
...
@@ -396,7 +410,7 @@ class Cos(UnaryScalarOp):
...
@@ -396,7 +410,7 @@ class Cos(UnaryScalarOp):
def
c_code
(
self
,
(
x
,
),
(
z
,
),
sub
):
def
c_code
(
self
,
(
x
,
),
(
z
,
),
sub
):
return
"
%(z)
s = cos(
%(x)
s);"
%
locals
()
return
"
%(z)
s = cos(
%(x)
s);"
%
locals
()
class
Sin
(
UnaryScalarOp
):
class
Sin
(
Float
UnaryScalarOp
):
def
impl
(
self
,
x
):
def
impl
(
self
,
x
):
return
math
.
sin
(
x
)
return
math
.
sin
(
x
)
def
grad
(
self
,
(
x
,
),
(
gz
,
)):
def
grad
(
self
,
(
x
,
),
(
gz
,
)):
...
@@ -404,15 +418,15 @@ class Sin(UnaryScalarOp):
...
@@ -404,15 +418,15 @@ class Sin(UnaryScalarOp):
def
c_code
(
self
,
(
x
,
),
(
z
,
),
sub
):
def
c_code
(
self
,
(
x
,
),
(
z
,
),
sub
):
return
"
%(z)
s = sin(
%(x)
s);"
%
locals
()
return
"
%(z)
s = sin(
%(x)
s);"
%
locals
()
class
Tan
(
UnaryScalarOp
):
class
Tan
(
Float
UnaryScalarOp
):
def
impl
(
self
,
x
):
def
impl
(
self
,
x
):
return
math
.
tan
(
x
)
return
math
.
tan
(
x
)
def
grad
(
self
,
(
x
,
),
(
gz
,
)):
def
grad
(
self
,
(
x
,
),
(
gz
,
)):
raise
NotImplementedError
(
'lazy'
)
raise
NotImplementedError
()
def
c_code
(
self
,
(
x
,
),
(
z
,
),
sub
):
def
c_code
(
self
,
(
x
,
),
(
z
,
),
sub
):
return
"
%(z)
s = tan(
%(x)
s);"
%
locals
()
return
"
%(z)
s = tan(
%(x)
s);"
%
locals
()
class
Cosh
(
UnaryScalarOp
):
class
Cosh
(
Float
UnaryScalarOp
):
def
impl
(
self
,
x
):
def
impl
(
self
,
x
):
return
math
.
cosh
(
x
)
return
math
.
cosh
(
x
)
def
grad
(
self
,
(
x
,
),
(
gz
,
)):
def
grad
(
self
,
(
x
,
),
(
gz
,
)):
...
@@ -420,15 +434,15 @@ class Cosh(UnaryScalarOp):
...
@@ -420,15 +434,15 @@ class Cosh(UnaryScalarOp):
def
c_code
(
self
,
(
x
,
),
(
z
,
),
sub
):
def
c_code
(
self
,
(
x
,
),
(
z
,
),
sub
):
return
"
%(z)
s = cosh(
%(x)
s);"
%
locals
()
return
"
%(z)
s = cosh(
%(x)
s);"
%
locals
()
class
Sinh
(
UnaryScalarOp
):
class
Sinh
(
Float
UnaryScalarOp
):
def
impl
(
self
,
x
):
def
impl
(
self
,
x
):
return
math
.
sin
(
x
)
return
math
.
sin
h
(
x
)
def
grad
(
self
,
(
x
,
),
(
gz
,
)):
def
grad
(
self
,
(
x
,
),
(
gz
,
)):
r
eturn
-
gz
*
cos
(
x
),
r
aise
NotImplementedError
()
def
c_code
(
self
,
(
x
,
),
(
z
,
),
sub
):
def
c_code
(
self
,
(
x
,
),
(
z
,
),
sub
):
return
"
%(z)
s = sin(
%(x)
s);"
%
locals
()
return
"
%(z)
s = sin(
%(x)
s);"
%
locals
()
class
Tanh
(
UnaryScalarOp
):
class
Tanh
(
Float
UnaryScalarOp
):
def
impl
(
self
,
x
):
def
impl
(
self
,
x
):
return
math
.
tanh
(
x
)
return
math
.
tanh
(
x
)
def
grad
(
self
,
(
x
,
),
(
gz
,
)):
def
grad
(
self
,
(
x
,
),
(
gz
,
)):
...
@@ -496,9 +510,6 @@ def composite(inputs, outputs):
...
@@ -496,9 +510,6 @@ def composite(inputs, outputs):
i
+=
1
i
+=
1
name
=
"V
%%(id)
s_tmp
%
i"
%
i
name
=
"V
%%(id)
s_tmp
%
i"
%
i
subd
[
output
]
=
name
subd
[
output
]
=
name
# the c code is not robust to any other dtypes than those of the specified inputs
# a solution would be to require Composite.c_code to fill in the dtypes using
# a proper upcast
_c_code
+=
"
%
s
%
s;
\n
"
%
(
output
.
dtype_specs
()[
1
],
name
)
_c_code
+=
"
%
s
%
s;
\n
"
%
(
output
.
dtype_specs
()[
1
],
name
)
_c_code
+=
op
.
c_code
([
subd
[
input
]
for
input
in
op
.
inputs
],
_c_code
+=
op
.
c_code
([
subd
[
input
]
for
input
in
op
.
inputs
],
[
subd
[
output
]
for
output
in
op
.
outputs
],
[
subd
[
output
]
for
output
in
op
.
outputs
],
...
@@ -529,7 +540,9 @@ def composite(inputs, outputs):
...
@@ -529,7 +540,9 @@ def composite(inputs, outputs):
nin
=
len
(
inputs
)
nin
=
len
(
inputs
)
nout
=
len
(
outputs
)
nout
=
len
(
outputs
)
# todo: propagate_dtypes?
def
output_dtypes
(
self
,
*
input_dtypes
):
assert
input_dtypes
==
tuple
([
input
.
dtype
for
input
in
inputs
])
return
[
output
.
dtype
for
dtype
in
outputs
]
def
perform
(
self
):
def
perform
(
self
):
inputs
=
[
input
.
data
for
input
in
self
.
inputs
]
inputs
=
[
input
.
data
for
input
in
self
.
inputs
]
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论