Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
286e1df5
提交
286e1df5
authored
8月 30, 2012
作者:
Ian Goodfellow
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix (mostly pre-existing) pep8 violations in tensor.basic
上级
ba4ba5ef
隐藏空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
25 行增加
和
16 行删除
+25
-16
basic.py
theano/tensor/basic.py
+25
-16
没有找到文件。
theano/tensor/basic.py
浏览文件 @
286e1df5
...
@@ -48,6 +48,7 @@ continuous_dtypes = map(str, scal.continuous_types)
...
@@ -48,6 +48,7 @@ continuous_dtypes = map(str, scal.continuous_types)
discrete_dtypes
=
map
(
str
,
scal
.
discrete_types
)
discrete_dtypes
=
map
(
str
,
scal
.
discrete_types
)
all_dtypes
=
map
(
str
,
scal
.
all_types
)
all_dtypes
=
map
(
str
,
scal
.
all_types
)
class
ShapeError
(
Exception
):
class
ShapeError
(
Exception
):
"""Raised when the shape cannot be computed."""
"""Raised when the shape cannot be computed."""
pass
pass
...
@@ -395,6 +396,7 @@ def constant(x, name=None, ndim=None, dtype=None):
...
@@ -395,6 +396,7 @@ def constant(x, name=None, ndim=None, dtype=None):
return
constant_or_value
(
x
,
rtype
=
TensorConstant
,
name
=
name
,
ndim
=
ndim
,
return
constant_or_value
(
x
,
rtype
=
TensorConstant
,
name
=
name
,
ndim
=
ndim
,
dtype
=
dtype
)
dtype
=
dtype
)
def
_obj_is_wrappable_as_tensor
(
x
):
def
_obj_is_wrappable_as_tensor
(
x
):
try
:
try
:
constant
(
x
)
constant
(
x
)
...
@@ -406,7 +408,7 @@ def _obj_is_wrappable_as_tensor(x):
...
@@ -406,7 +408,7 @@ def _obj_is_wrappable_as_tensor(x):
def
_wrap_tensor_into_member
(
x
):
def
_wrap_tensor_into_member
(
x
):
return
compile
.
module
.
Member
(
constant
(
x
))
return
compile
.
module
.
Member
(
constant
(
x
))
compile
.
module
.
register_wrapper
(
_obj_is_wrappable_as_tensor
,
compile
.
module
.
register_wrapper
(
_obj_is_wrappable_as_tensor
,
_wrap_tensor_into_member
,
no_warn
=
True
)
_wrap_tensor_into_member
,
no_warn
=
True
)
if
int
(
config
.
tensor
.
cmp_sloppy
)
>
1
:
if
int
(
config
.
tensor
.
cmp_sloppy
)
>
1
:
...
@@ -1503,10 +1505,9 @@ class _tensor_py_operators:
...
@@ -1503,10 +1505,9 @@ class _tensor_py_operators:
"""
"""
if
ndim
is
not
None
:
if
ndim
is
not
None
:
if
not
isinstance
(
ndim
,
int
):
if
not
isinstance
(
ndim
,
int
):
raise
ValueError
(
"Expected ndim to be an integer, is "
\
raise
ValueError
(
"Expected ndim to be an integer, is "
\
+
str
(
type
(
ndim
)))
+
str
(
type
(
ndim
)))
return
reshape
(
self
,
shape
,
ndim
=
ndim
)
return
reshape
(
self
,
shape
,
ndim
=
ndim
)
...
@@ -1803,7 +1804,6 @@ class TensorConstant(_tensor_py_operators, Constant):
...
@@ -1803,7 +1804,6 @@ class TensorConstant(_tensor_py_operators, Constant):
TensorType
.
Constant
=
TensorConstant
TensorType
.
Constant
=
TensorConstant
Tensor
=
TensorType
Tensor
=
TensorType
...
@@ -1817,6 +1817,7 @@ elemwise.TensorConstant = TensorConstant
...
@@ -1817,6 +1817,7 @@ elemwise.TensorConstant = TensorConstant
# Utilities
# Utilities
#########################
#########################
def
_redefine
(
real_symbol_value
,
module
=
'tensor'
):
def
_redefine
(
real_symbol_value
,
module
=
'tensor'
):
"""Replace the value associated with a function symbol.
"""Replace the value associated with a function symbol.
...
@@ -2062,6 +2063,7 @@ def cast(x, dtype):
...
@@ -2062,6 +2063,7 @@ def cast(x, dtype):
# Unary Operations
# Unary Operations
##########################
##########################
class
Shape
(
Op
):
class
Shape
(
Op
):
"""
"""
L{Op} to return the shape of a matrix.
L{Op} to return the shape of a matrix.
...
@@ -2333,16 +2335,16 @@ class MaxAndArgmax(Op):
...
@@ -2333,16 +2335,16 @@ class MaxAndArgmax(Op):
#if the op is totally disconnected, so are its inputs
#if the op is totally disconnected, so are its inputs
if
g_max_disconnected
and
g_max_idx_disconnected
:
if
g_max_disconnected
and
g_max_idx_disconnected
:
return
[
DisconnectedType
()(),
DisconnectedType
()()
]
return
[
DisconnectedType
()(),
DisconnectedType
()()
]
axis_grad
=
grad_undefined
(
self
,
1
,
axis
,
axis_grad
=
grad_undefined
(
self
,
1
,
axis
,
"argmax is not defined for non-integer axes so"
"argmax is not defined for non-integer axes so"
" argmax(x, axis+eps) is undefined"
)
" argmax(x, axis+eps) is undefined"
)
#if the max is disconnected but the argmax is not,
#if the max is disconnected but the argmax is not,
#the gradient on its inputs is zero
#the gradient on its inputs is zero
if
g_max_disconnected
:
if
g_max_disconnected
:
return
[
x
.
zeros_like
(),
axis_grad
]
return
[
x
.
zeros_like
(),
axis_grad
]
xmax
=
max
(
x
,
axis
)
xmax
=
max
(
x
,
axis
)
# Raise the g_max and xmax to the same number of dim as the input.
# Raise the g_max and xmax to the same number of dim as the input.
...
@@ -2875,6 +2877,7 @@ def complex_from_polar(abs, angle):
...
@@ -2875,6 +2877,7 @@ def complex_from_polar(abs, angle):
# Misc
# Misc
##########################
##########################
#fill, _fill_inplace = _elemwise(scal.second, 'fill',
#fill, _fill_inplace = _elemwise(scal.second, 'fill',
#"""fill WRITEME (elemwise)""")
#"""fill WRITEME (elemwise)""")
@_scal_elemwise
@_scal_elemwise
...
@@ -2943,7 +2946,7 @@ class Eye(gof.Op):
...
@@ -2943,7 +2946,7 @@ class Eye(gof.Op):
return
[
out_shape
]
return
[
out_shape
]
def
grad
(
self
,
inp
,
grads
):
def
grad
(
self
,
inp
,
grads
):
return
[
grad_undefined
(
self
,
i
,
inp
[
i
])
for
i
in
xrange
(
3
)
]
return
[
grad_undefined
(
self
,
i
,
inp
[
i
])
for
i
in
xrange
(
3
)
]
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
return
type
(
self
)
==
type
(
other
)
and
self
.
dtype
==
other
.
dtype
return
type
(
self
)
==
type
(
other
)
and
self
.
dtype
==
other
.
dtype
...
@@ -3418,15 +3421,20 @@ def var(input, axis=None, keepdims=False):
...
@@ -3418,15 +3421,20 @@ def var(input, axis=None, keepdims=False):
@constructor
@constructor
def
std
(
input
,
axis
=
None
,
keepdims
=
False
):
def
std
(
input
,
axis
=
None
,
keepdims
=
False
):
"""
"""
Computes the standard deviation along the given axis(es) of a tensor `input`.
Computes the standard deviation along the given axis(es)
of a tensor `input`.
:param axis: Compute the standard deviation along this axis of the tensor.
:param axis: Compute the standard deviation along this
axis of the tensor.
None means all axes (like numpy).
None means all axes (like numpy).
:type axis: None or int or (list of int) (see `Sum`)
:type axis: None or int or (list of int) (see `Sum`)
:param keepdims: If this is set to True, the axes which are reduced are
:param keepdims: If this is set to True, the axes
left in the result as dimensions with size one. With this option,
which are reduced are
the result will broadcast correctly against the original tensor.
left in the result as dimensions with size one.
With this option,
the result will broadcast correctly against the
original tensor.
"""
"""
return
sqrt
(
var
(
input
=
input
,
axis
=
axis
,
keepdims
=
keepdims
))
return
sqrt
(
var
(
input
=
input
,
axis
=
axis
,
keepdims
=
keepdims
))
...
@@ -5402,7 +5410,6 @@ class Reshape(Op):
...
@@ -5402,7 +5410,6 @@ class Reshape(Op):
raise
ValueError
(
'Cannot reshape input of shape
%
s to shape
%
s'
%
raise
ValueError
(
'Cannot reshape input of shape
%
s to shape
%
s'
%
(
x
.
shape
,
shp
))
(
x
.
shape
,
shp
))
def
grad
(
self
,
inp
,
grads
):
def
grad
(
self
,
inp
,
grads
):
x
,
shp
=
inp
x
,
shp
=
inp
g_out
,
=
grads
g_out
,
=
grads
...
@@ -5488,7 +5495,8 @@ class Reshape(Op):
...
@@ -5488,7 +5495,8 @@ class Reshape(Op):
%(shp)
s->data + ii *
%(shp)
s->strides[0]))[0];
%(shp)
s->data + ii *
%(shp)
s->strides[0]))[0];
}
}
Py_XDECREF(
%(z)
s);
Py_XDECREF(
%(z)
s);
%(z)
s = (PyArrayObject *) PyArray_Newshape(
%(x)
s, &newshape, PyArray_CORDER);
%(z)
s = (PyArrayObject *) PyArray_Newshape(
%(x)
s, &newshape,
PyArray_CORDER);
if (!
%(z)
s)
if (!
%(z)
s)
{
{
PyErr_Format(PyExc_ValueError,
PyErr_Format(PyExc_ValueError,
...
@@ -6351,6 +6359,7 @@ advanced_inc_subtensor = AdvancedIncSubtensor()
...
@@ -6351,6 +6359,7 @@ advanced_inc_subtensor = AdvancedIncSubtensor()
#
#
# TODO: Dotinv should go here, Eigs, Svd, etc.
# TODO: Dotinv should go here, Eigs, Svd, etc.
class
Dot
(
Op
):
class
Dot
(
Op
):
"""Compute matrix-matrix, matrix-vector products and vector inner-products.
"""Compute matrix-matrix, matrix-vector products and vector inner-products.
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论