Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
181d5566
提交
181d5566
authored
4月 14, 2022
作者:
Ricardo
提交者:
Brandon T. Willard
4月 18, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Allow partial shape information in `SpecifyShape` `Op`
Co-authored-by:
Brandon T. Willard
<
971601+brandonwillard@users.noreply.github.com
>
上级
383600bc
隐藏空白字符变更
内嵌
并排
正在显示
10 个修改的文件
包含
229 行增加
和
158 行删除
+229
-158
dispatch.py
aesara/link/jax/dispatch.py
+1
-1
basic.py
aesara/link/numba/dispatch/basic.py
+23
-7
basic_opt.py
aesara/tensor/basic_opt.py
+18
-2
shape.py
aesara/tensor/shape.py
+112
-140
subtensor_opt.py
aesara/tensor/subtensor_opt.py
+1
-1
test_jax.py
tests/link/test_jax.py
+1
-1
test_numba.py
tests/link/test_numba.py
+6
-1
test_basic_opt.py
tests/tensor/test_basic_opt.py
+18
-0
test_shape.py
tests/tensor/test_shape.py
+48
-4
test_sharedvar.py
tests/tensor/test_sharedvar.py
+1
-1
没有找到文件。
aesara/link/jax/dispatch.py
浏览文件 @
181d5566
...
@@ -335,7 +335,7 @@ def jax_funcify_Shape_i(op, **kwargs):
...
@@ -335,7 +335,7 @@ def jax_funcify_Shape_i(op, **kwargs):
@jax_funcify.register
(
SpecifyShape
)
@jax_funcify.register
(
SpecifyShape
)
def
jax_funcify_SpecifyShape
(
op
,
**
kwargs
):
def
jax_funcify_SpecifyShape
(
op
,
**
kwargs
):
def
specifyshape
(
x
,
shape
):
def
specifyshape
(
x
,
*
shape
):
assert
x
.
ndim
==
len
(
shape
)
assert
x
.
ndim
==
len
(
shape
)
assert
jnp
.
all
(
x
.
shape
==
tuple
(
shape
)),
(
assert
jnp
.
all
(
x
.
shape
==
tuple
(
shape
)),
(
"got shape"
,
"got shape"
,
...
...
aesara/link/numba/dispatch/basic.py
浏览文件 @
181d5566
...
@@ -2,6 +2,7 @@ import operator
...
@@ -2,6 +2,7 @@ import operator
import
warnings
import
warnings
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
functools
import
singledispatch
from
functools
import
singledispatch
from
textwrap
import
dedent
import
numba
import
numba
import
numba.np.unsafe.ndarray
as
numba_ndarray
import
numba.np.unsafe.ndarray
as
numba_ndarray
...
@@ -40,7 +41,7 @@ from aesara.tensor.subtensor import (
...
@@ -40,7 +41,7 @@ from aesara.tensor.subtensor import (
Subtensor
,
Subtensor
,
)
)
from
aesara.tensor.type
import
TensorType
from
aesara.tensor.type
import
TensorType
from
aesara.tensor.type_other
import
MakeSlice
from
aesara.tensor.type_other
import
MakeSlice
,
NoneConst
def
numba_njit
(
*
args
,
**
kwargs
):
def
numba_njit
(
*
args
,
**
kwargs
):
...
@@ -609,13 +610,28 @@ def numba_funcify_Reshape(op, **kwargs):
...
@@ -609,13 +610,28 @@ def numba_funcify_Reshape(op, **kwargs):
@numba_funcify.register
(
SpecifyShape
)
@numba_funcify.register
(
SpecifyShape
)
def
numba_funcify_SpecifyShape
(
op
,
**
kwargs
):
def
numba_funcify_SpecifyShape
(
op
,
node
,
**
kwargs
):
@numba_njit
shape_inputs
=
node
.
inputs
[
1
:]
def
specifyshape
(
x
,
shape
):
shape_input_names
=
[
"shape_"
+
str
(
i
)
for
i
in
range
(
len
(
shape_inputs
))]
assert
np
.
array_equal
(
x
.
shape
,
shape
)
return
x
func_conditions
=
[
f
"assert x.shape[{i}] == {shape_input_names}"
for
i
,
(
shape_input
,
shape_input_names
)
in
enumerate
(
zip
(
shape_inputs
,
shape_input_names
)
)
if
shape_input
is
not
NoneConst
]
func
=
dedent
(
f
"""
def specify_shape(x, {create_arg_string(shape_input_names)}):
{"; ".join(func_conditions)}
return x
"""
)
return
specifyshape
specify_shape
=
compile_function_src
(
func
,
"specify_shape"
,
globals
())
return
numba_njit
(
specify_shape
)
def
int_to_float_fn
(
inputs
,
out_dtype
):
def
int_to_float_fn
(
inputs
,
out_dtype
):
...
...
aesara/tensor/basic_opt.py
浏览文件 @
181d5566
...
@@ -64,6 +64,7 @@ from aesara.tensor.basic import (
...
@@ -64,6 +64,7 @@ from aesara.tensor.basic import (
join
,
join
,
ones_like
,
ones_like
,
patternbroadcast
,
patternbroadcast
,
stack
,
switch
,
switch
,
tensor_copy
,
tensor_copy
,
unbroadcast
,
unbroadcast
,
...
@@ -75,7 +76,14 @@ from aesara.tensor.exceptions import NotScalarConstantError, ShapeError
...
@@ -75,7 +76,14 @@ from aesara.tensor.exceptions import NotScalarConstantError, ShapeError
from
aesara.tensor.extra_ops
import
BroadcastTo
,
Repeat
,
Unique
,
broadcast_shape
from
aesara.tensor.extra_ops
import
BroadcastTo
,
Repeat
,
Unique
,
broadcast_shape
from
aesara.tensor.math
import
all
as
at_all
from
aesara.tensor.math
import
all
as
at_all
from
aesara.tensor.math
import
eq
from
aesara.tensor.math
import
eq
from
aesara.tensor.shape
import
Reshape
,
Shape
,
Shape_i
,
SpecifyShape
,
shape_padleft
from
aesara.tensor.shape
import
(
Reshape
,
Shape
,
Shape_i
,
SpecifyShape
,
shape_i
,
shape_padleft
,
)
from
aesara.tensor.sort
import
TopKOp
from
aesara.tensor.sort
import
TopKOp
from
aesara.tensor.subtensor
import
Subtensor
,
get_idx_list
from
aesara.tensor.subtensor
import
Subtensor
,
get_idx_list
from
aesara.tensor.type
import
(
from
aesara.tensor.type
import
(
...
@@ -84,6 +92,7 @@ from aesara.tensor.type import (
...
@@ -84,6 +92,7 @@ from aesara.tensor.type import (
discrete_dtypes
,
discrete_dtypes
,
integer_dtypes
,
integer_dtypes
,
)
)
from
aesara.tensor.type_other
import
NoneConst
from
aesara.tensor.var
import
TensorConstant
from
aesara.tensor.var
import
TensorConstant
from
aesara.utils
import
NoDuplicateOptWarningFilter
from
aesara.utils
import
NoDuplicateOptWarningFilter
...
@@ -3521,7 +3530,14 @@ def local_Shape_of_SpecifyShape(fgraph, node):
...
@@ -3521,7 +3530,14 @@ def local_Shape_of_SpecifyShape(fgraph, node):
if
not
isinstance
(
getattr
(
specified_shape
.
owner
,
"op"
,
None
),
SpecifyShape
):
if
not
isinstance
(
getattr
(
specified_shape
.
owner
,
"op"
,
None
),
SpecifyShape
):
return
False
return
False
return
[
specified_shape
.
owner
.
inputs
[
1
]
.
astype
(
np
.
int64
)]
x
,
*
shape
=
specified_shape
.
owner
.
inputs
# Replace `NoneConst` by `shape_i`
for
i
,
sh
in
enumerate
(
shape
):
if
NoneConst
.
equals
(
sh
):
shape
[
i
]
=
shape_i
(
x
,
i
,
fgraph
)
return
[
stack
(
shape
)
.
astype
(
np
.
int64
)]
@register_useless
@register_useless
...
...
aesara/tensor/shape.py
浏览文件 @
181d5566
import
warnings
import
warnings
from
numbers
import
Number
from
numbers
import
Number
from
textwrap
import
dedent
from
typing
import
Dict
,
List
,
Tuple
,
Union
from
typing
import
Dict
,
List
,
Tuple
,
Union
import
numpy
as
np
import
numpy
as
np
import
aesara
import
aesara
from
aesara.gradient
import
DisconnectedType
from
aesara.gradient
import
DisconnectedType
from
aesara.graph.basic
import
Apply
,
Constant
,
Variable
from
aesara.graph.basic
import
Apply
,
Variable
from
aesara.link.c.op
import
COp
from
aesara.link.c.op
import
COp
from
aesara.link.c.params_type
import
ParamsType
from
aesara.link.c.params_type
import
ParamsType
from
aesara.misc.safe_asarray
import
_asarray
from
aesara.misc.safe_asarray
import
_asarray
...
@@ -15,7 +16,8 @@ from aesara.tensor import _get_vector_length
...
@@ -15,7 +16,8 @@ from aesara.tensor import _get_vector_length
from
aesara.tensor
import
basic
as
at
from
aesara.tensor
import
basic
as
at
from
aesara.tensor
import
get_vector_length
from
aesara.tensor
import
get_vector_length
from
aesara.tensor.exceptions
import
NotScalarConstantError
from
aesara.tensor.exceptions
import
NotScalarConstantError
from
aesara.tensor.type
import
TensorType
,
int_dtypes
,
tensor
from
aesara.tensor.type
import
DenseTensorType
,
TensorType
,
int_dtypes
,
tensor
from
aesara.tensor.type_other
import
NoneConst
from
aesara.tensor.var
import
TensorConstant
,
TensorVariable
from
aesara.tensor.var
import
TensorConstant
,
TensorVariable
...
@@ -362,28 +364,6 @@ def register_shape_i_c_code(typ, code, check_input, version=()):
...
@@ -362,28 +364,6 @@ def register_shape_i_c_code(typ, code, check_input, version=()):
Shape_i
.
c_code_and_version
[
typ
]
=
(
code
,
check_input
,
version
)
Shape_i
.
c_code_and_version
[
typ
]
=
(
code
,
check_input
,
version
)
def
register_specify_shape_c_code
(
typ
,
code
,
version
=
(),
c_support_code_apply
=
None
):
"""
Tell SpecifyShape how to generate C code for an Aesara Type.
Parameters
----------
typ : Aesara type
It must be the Aesara class itself and not an instance of the class.
code : C code
Checks the shape and returns a view for the Aesara type 'typ'.
Use
%(iname)
s and
%(oname)
s for the input and output C variable names
respectively.
%(shape)
s is the vector of shape of
%(iname)
s.
Check that its length is good.
version
A number indicating the version of the code, for cache.
c_support_code_apply
Extra code.
"""
SpecifyShape
.
c_code_and_version
[
typ
]
=
(
code
,
version
,
c_support_code_apply
)
class
SpecifyShape
(
COp
):
class
SpecifyShape
(
COp
):
"""
"""
L{Op} that puts into the graph the user-provided shape.
L{Op} that puts into the graph the user-provided shape.
...
@@ -396,33 +376,29 @@ class SpecifyShape(COp):
...
@@ -396,33 +376,29 @@ class SpecifyShape(COp):
Notes
Notes
-----
-----
Maybe in the future we will never do the assert!
Maybe in the future we will never do the assert!
We currently don't support specifying partial shape information.
TODO : test this op with sparse. Do C code for them too.
"""
"""
view_map
=
{
0
:
[
0
]}
view_map
=
{
0
:
[
0
]}
# Mapping from Type to C code (and version) to use.
# In the C code, the name of the input variable is %(iname)s,
# the output variable is %(oname)s.
c_code_and_version
:
Dict
=
{}
__props__
=
()
__props__
=
()
_f16_ok
=
True
_f16_ok
=
True
def
make_node
(
self
,
x
,
shape
):
def
make_node
(
self
,
x
,
*
shape
):
if
not
isinstance
(
x
,
Variable
):
from
aesara.tensor.basic
import
get_scalar_constant_value
x
=
at
.
as_tensor_variable
(
x
)
shape
=
at
.
as_tensor_variable
(
shape
,
ndim
=
1
)
x
=
at
.
as_tensor_variable
(
x
)
if
isinstance
(
shape
,
Constant
):
shape
=
tuple
(
shape
=
tuple
(
shape
.
data
)
NoneConst
else
:
if
(
s
is
None
or
NoneConst
.
equals
(
s
))
shape
=
tuple
(
at
.
as_tensor_variable
(
s
,
ndim
=
0
)
for
s
in
shape
)
else
at
.
as_tensor_variable
(
s
,
ndim
=
0
)
for
s
in
shape
)
if
any
(
s
.
dtype
not
in
aesara
.
tensor
.
type
.
integer_dtypes
for
s
in
shape
):
if
any
(
s
.
dtype
not
in
aesara
.
tensor
.
type
.
integer_dtypes
for
s
in
shape
if
hasattr
(
s
,
"dtype"
)
):
raise
TypeError
(
"Shape values must be integer types"
)
raise
TypeError
(
"Shape values must be integer types"
)
if
len
(
shape
)
!=
x
.
type
.
ndim
:
if
len
(
shape
)
!=
x
.
type
.
ndim
:
...
@@ -430,102 +406,127 @@ class SpecifyShape(COp):
...
@@ -430,102 +406,127 @@ class SpecifyShape(COp):
f
"Input `x` is {x.type.ndim}-dimensional and will never match a shape of length {len(shape)}."
f
"Input `x` is {x.type.ndim}-dimensional and will never match a shape of length {len(shape)}."
)
)
if
isinstance
(
x
.
type
,
TensorType
)
and
all
(
isinstance
(
s
,
Number
)
for
s
in
shape
):
type_shape
=
[
None
]
*
x
.
ndim
out_var
=
x
.
type
.
clone
(
shape
=
shape
)()
for
i
,
(
xts
,
s
)
in
enumerate
(
zip
(
x
.
type
.
shape
,
shape
)):
else
:
if
xts
is
not
None
:
out_var
=
x
.
type
()
type_shape
[
i
]
=
xts
else
:
try
:
type_s
=
get_scalar_constant_value
(
s
)
if
type_s
is
not
None
:
type_shape
[
i
]
=
int
(
type_s
)
except
NotScalarConstantError
:
pass
out_var
=
x
.
type
.
clone
(
shape
=
type_shape
)()
in_shape
=
at
.
as_tensor_variable
(
shape
,
ndim
=
1
)
return
Apply
(
self
,
[
x
,
*
shape
],
[
out_var
])
return
Apply
(
self
,
[
x
,
in_shape
],
[
out_var
])
def
perform
(
self
,
node
,
inp
,
out_
):
def
perform
(
self
,
node
,
inp
,
out_
):
x
,
shape
=
inp
x
,
*
shape
=
inp
(
out
,)
=
out_
(
out
,)
=
out_
ndim
=
len
(
shape
)
ndim
=
len
(
shape
)
if
x
.
ndim
!=
ndim
:
if
x
.
ndim
!=
ndim
:
raise
AssertionError
(
raise
AssertionError
(
f
"SpecifyShape: Got {x.ndim} dimensions (shape {x.shape}), expected {ndim} dimensions with shape {tuple(shape)}."
f
"SpecifyShape: Got {x.ndim} dimensions (shape {x.shape}), expected {ndim} dimensions with shape {tuple(shape)}."
)
)
if
x
.
shape
!=
tuple
(
shap
e
):
if
not
all
(
xs
==
s
for
xs
,
s
in
zip
(
x
.
shape
,
shape
)
if
s
is
not
Non
e
):
raise
AssertionError
(
raise
AssertionError
(
f
"SpecifyShape: Got shape {x.shape}, expected {tuple(shape)}."
f
"SpecifyShape: Got shape {x.shape}, expected {tuple(
int(s) if s is not None else None for s in
shape)}."
)
)
out
[
0
]
=
x
out
[
0
]
=
x
def
infer_shape
(
self
,
fgraph
,
node
,
shapes
):
def
infer_shape
(
self
,
fgraph
,
node
,
shapes
):
xshape
,
sshape
=
shapes
xshape
,
*
_
=
shapes
shape
=
node
.
inputs
[
1
:]
new_shape
=
[]
new_shape
=
[]
for
dim
in
range
(
node
.
inputs
[
0
]
.
type
.
ndim
):
for
dim
in
range
(
node
.
inputs
[
0
]
.
type
.
ndim
):
s
=
shape
[
dim
]
try
:
try
:
s
=
at
.
get_scalar_constant_value
(
node
.
inputs
[
1
][
dim
])
s
=
at
.
get_scalar_constant_value
(
s
)
s
=
at
.
as_tensor_variable
(
s
)
# We assume that `None` shapes are always retrieved by
new_shape
.
append
(
s
)
# `get_scalar_constant_value`, and only in that case do we default to
# the shape of the input variable
if
s
is
None
:
s
=
xshape
[
dim
]
except
NotScalarConstantError
:
except
NotScalarConstantError
:
new_shape
.
append
(
node
.
inputs
[
1
][
dim
])
pass
new_shape
.
append
(
at
.
as_tensor_variable
(
s
))
assert
len
(
new_shape
)
==
len
(
xshape
)
assert
len
(
new_shape
)
==
len
(
xshape
)
return
[
new_shape
]
return
[
new_shape
]
def
connection_pattern
(
self
,
node
):
def
connection_pattern
(
self
,
node
):
return
[[
True
],
[
False
]
]
return
[[
True
],
*
[[
False
]]
*
len
(
node
.
inputs
[
1
:])
]
def
grad
(
self
,
inp
,
grads
):
def
grad
(
self
,
inp
,
grads
):
x
,
s
=
inp
x
,
*
shape
=
inp
(
gz
,)
=
grads
(
gz
,)
=
grads
# Should I set an SpecifyShape on gz? I think so
# Should I set an SpecifyShape on gz? I think so
# But I don't do it now as we need to make an optimization
# But I don't do it now as we need to make an optimization
# to remove that op from the graph to don't block other optimization
# to remove that op from the graph to don't block other optimization
# Should I do an optimizer that will remove the SpecifyShape?
# Should I do an optimizer that will remove the SpecifyShape?
# I think Yes
# I think Yes
return
[
gz
,
aesara
.
gradient
.
DisconnectedType
()(
)]
# return [specify_shape(gz, s)] + [aesara.gradient.DisconnectedType()() for _ in range(len(shape)
)]
return
[
specify_shape
(
gz
,
s
),
aesara
.
gradient
.
DisconnectedType
()(
)]
return
[
gz
]
+
[
aesara
.
gradient
.
DisconnectedType
()()
for
_
in
range
(
len
(
shape
)
)]
def
R_op
(
self
,
inputs
,
eval_points
):
def
R_op
(
self
,
inputs
,
eval_points
):
if
eval_points
[
0
]
is
None
:
if
eval_points
[
0
]
is
None
:
# It means that the this op sits on top of a non-differentiable
# It means that this op sits on top of a non-differentiable path
# path
return
[
None
]
return
[
None
]
return
self
.
make_node
(
eval_points
[
0
],
*
inputs
[
1
:])
.
outputs
return
self
.
make_node
(
eval_points
[
0
],
*
inputs
[
1
:])
.
outputs
def
c_support_code_apply
(
self
,
node
,
name
):
def
c_code
(
self
,
node
,
name
,
i_names
,
o_names
,
sub
):
itype
=
node
.
inputs
[
0
]
.
type
.
__class__
if
not
isinstance
(
node
.
inputs
[
0
]
.
type
,
DenseTensorType
):
if
itype
in
self
.
c_code_and_version
:
raise
NotImplementedError
(
_
,
_
,
support_code
=
self
.
c_code_and_version
[
itype
]
f
"Specify_shape c_code not implemented for input type {node.inputs[0].type}"
if
support_code
:
)
return
support_code
return
super
()
.
c_support_code_apply
(
node
,
name
)
def
c_code
(
self
,
node
,
name
,
inames
,
onames
,
sub
):
x_name
,
*
shape_names
=
i_names
iname
,
shape
=
inames
(
o_name
,)
=
o_names
(
oname
,)
=
onames
fail
=
sub
[
"fail"
]
fail
=
sub
[
"fail"
]
itype
=
node
.
inputs
[
0
]
.
type
.
__class__
code
=
dedent
(
if
itype
in
self
.
c_code_and_version
:
f
"""
code
,
version
,
_
=
self
.
c_code_and_version
[
itype
]
if (PyArray_NDIM({x_name}) != {len(shape_names)}) {{
return
code
%
locals
()
PyErr_Format(PyExc_AssertionError,
"SpecifyShape: Got
%
d dimensions, expected
%
d dimensions.",
PyArray_NDIM({x_name}), {len(shape_names)}
);
{fail};
}}
"""
)
raise
NotImplementedError
()
for
i
,
(
shp_name
,
shp
)
in
enumerate
(
zip
(
shape_names
,
node
.
inputs
[
1
:])):
if
NoneConst
.
equals
(
shp
):
continue
code
+=
dedent
(
f
"""
if (py_{shp_name} != Py_None){{
dtype_{shp_name} shp = ((dtype_{shp_name}*)PyArray_GETPTR1({shp_name}, 0))[0];
if (PyArray_DIMS({x_name})[{i}] != shp) {{
PyErr_Format(PyExc_AssertionError,
"SpecifyShape: dim
%
d of input has shape
%
d, expected
%
d.",
{i}, PyArray_DIMS({x_name})[{i}], shp
);
{fail};
}}
}}
"""
)
def
c_code_cache_version
(
self
):
code
+=
dedent
(
version
=
[]
f
"""
# If any of the c code is unversioned, we have to return ()
Py_XDECREF({o_name});
# Else, we will return a list of (type name, version) pairs.
{o_name} = {x_name};
for
t
,
(
c
,
v
,
_
)
in
sorted
(
Py_XINCREF({o_name});
self
.
c_code_and_version
.
items
(),
key
=
lambda
pair
:
str
(
pair
[
0
])
"""
):
)
if
not
v
:
return
code
warnings
.
warn
(
"Type
%
s has C code for SpecifyShape, but it "
"has no version. You should add a 'version' "
"keyword arg when calling "
"register_specify_shape_c_code."
%
t
,
stacklevel
=
2
,
)
return
()
version
.
append
((
str
(
t
),
v
))
return
tuple
(
version
)
def
c_code_cache_version
(
self
):
return
(
2
,)
_specify_shape
=
SpecifyShape
()
_specify_shape
=
SpecifyShape
()
...
@@ -537,29 +538,31 @@ def specify_shape(
...
@@ -537,29 +538,31 @@ def specify_shape(
int
,
List
[
Union
[
int
,
Variable
]],
Tuple
[
Union
[
int
,
Variable
]],
Variable
int
,
List
[
Union
[
int
,
Variable
]],
Tuple
[
Union
[
int
,
Variable
]],
Variable
],
],
):
):
"""Specify a fixed shape for a `Variable`.
"""
"""Specify a fixed shape for a `Variable`.
if
not
isinstance
(
x
,
Variable
):
If a dimension's shape value is ``None``, the size of that dimension is not considered fixed/static at runtime.
x
=
at
.
as_tensor_variable
(
x
)
"""
if
np
.
ndim
(
shape
)
==
0
:
shape
=
at
.
as_tensor_variable
([
shape
])
try
:
if
not
isinstance
(
shape
,
(
tuple
,
list
)):
_
=
get_vector_length
(
shape
)
shape
=
(
shape
,)
except
ValueError
:
raise
ValueError
(
"Shape must have fixed dimensions"
)
if
isinstance
(
shape
,
Constant
):
# If shape is a symbolic 1d vector of fixed length, we separate the items into a
shape
=
tuple
(
shape
.
data
)
# tuple with one entry per shape dimension
if
len
(
shape
)
==
1
and
shape
[
0
]
is
not
None
:
shape_vector
=
at
.
as_tensor_variable
(
shape
[
0
])
if
shape_vector
.
ndim
==
1
:
try
:
shape
=
tuple
(
shape_vector
)
except
ValueError
:
raise
ValueError
(
"Shape vector must have fixed dimensions"
)
return
_specify_shape
(
x
,
shape
)
return
_specify_shape
(
x
,
*
shape
)
@_get_vector_length.register
(
SpecifyShape
)
@_get_vector_length.register
(
SpecifyShape
)
def
_get_vector_length_SpecifyShape
(
op
,
var
):
def
_get_vector_length_SpecifyShape
(
op
,
var
):
try
:
try
:
return
at
.
get_scalar_constant_value
(
var
.
owner
.
inputs
[
1
])
return
at
.
get_scalar_constant_value
(
var
.
owner
.
inputs
[
1
])
.
item
()
except
NotScalarConstantError
:
except
NotScalarConstantError
:
raise
ValueError
(
f
"Length of {var} cannot be determined"
)
raise
ValueError
(
f
"Length of {var} cannot be determined"
)
...
@@ -882,34 +885,3 @@ register_shape_i_c_code(
...
@@ -882,34 +885,3 @@ register_shape_i_c_code(
"""
,
"""
,
version
=
3
,
version
=
3
,
)
)
register_specify_shape_c_code
(
TensorType
,
"""
if (PyArray_NDIM(
%(iname)
s) != PyArray_DIMS(
%(shape)
s)[0]) {
PyErr_Format(PyExc_AssertionError,
"SpecifyShape: Got
%%
d dimensions, expected
%%
d dimensions.",
PyArray_NDIM(
%(iname)
s),
PyArray_DIMS(
%(shape)
s)[0]
);
%(fail)
s;
}
for(int i = 0; i < PyArray_NDIM(
%(iname)
s); i++){
dtype_
%(shape)
s shp = ((dtype_
%(shape)
s*)PyArray_GETPTR1(
%(shape)
s,
i))[0];
if (PyArray_DIMS(
%(iname)
s)[i] != shp) {
PyErr_Format(PyExc_AssertionError,
"SpecifyShape: dim
%%
d of input has shape
%%
d,"
" expected
%%
d.",
i, PyArray_DIMS(
%(iname)
s)[i],
shp);
%(fail)
s;
}
}
Py_XDECREF(
%(oname)
s);
%(oname)
s =
%(iname)
s;
Py_XINCREF(
%(oname)
s);
"""
,
version
=
1
,
)
aesara/tensor/subtensor_opt.py
浏览文件 @
181d5566
...
@@ -1646,7 +1646,7 @@ def local_subtensor_SpecifyShape_lift(fgraph, node):
...
@@ -1646,7 +1646,7 @@ def local_subtensor_SpecifyShape_lift(fgraph, node):
return
False
return
False
obj_arg
=
specify_shape_node
.
owner
.
inputs
[
0
]
obj_arg
=
specify_shape_node
.
owner
.
inputs
[
0
]
shape_arg
=
specify_shape_node
.
owner
.
inputs
[
1
]
shape_arg
=
specify_shape_node
.
owner
.
inputs
[
1
:
]
indices
=
get_idx_list
(
node
.
inputs
,
node
.
op
.
idx_list
)
indices
=
get_idx_list
(
node
.
inputs
,
node
.
op
.
idx_list
)
...
...
tests/link/test_jax.py
浏览文件 @
181d5566
...
@@ -185,7 +185,7 @@ def test_jax_specify_shape():
...
@@ -185,7 +185,7 @@ def test_jax_specify_shape():
with
config
.
change_flags
(
compute_test_value
=
"off"
):
with
config
.
change_flags
(
compute_test_value
=
"off"
):
x
=
SpecifyShape
()(
at
.
as_tensor_variable
(
x_np
),
(
2
,
3
))
x
=
SpecifyShape
()(
at
.
as_tensor_variable
(
x_np
),
*
(
2
,
3
))
x_fg
=
FunctionGraph
([],
[
x
])
x_fg
=
FunctionGraph
([],
[
x
])
with
pytest
.
raises
(
AssertionError
):
with
pytest
.
raises
(
AssertionError
):
...
...
tests/link/test_numba.py
浏览文件 @
181d5566
...
@@ -896,10 +896,15 @@ def test_Reshape_scalar():
...
@@ -896,10 +896,15 @@ def test_Reshape_scalar():
(
1
,
1
),
(
1
,
1
),
True
,
True
,
),
),
(
set_test_value
(
at
.
matrix
(),
np
.
array
([[
1.0
,
2.0
]],
dtype
=
config
.
floatX
)),
(
1
,
None
),
False
,
),
],
],
)
)
def
test_SpecifyShape
(
v
,
shape
,
fails
):
def
test_SpecifyShape
(
v
,
shape
,
fails
):
g
=
SpecifyShape
()(
v
,
shape
)
g
=
SpecifyShape
()(
v
,
*
shape
)
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
cm
=
contextlib
.
suppress
()
if
not
fails
else
pytest
.
raises
(
AssertionError
)
cm
=
contextlib
.
suppress
()
if
not
fails
else
pytest
.
raises
(
AssertionError
)
with
cm
:
with
cm
:
...
...
tests/tensor/test_basic_opt.py
浏览文件 @
181d5566
...
@@ -2975,6 +2975,24 @@ def test_local_Shape_of_SpecifyShape(shape):
...
@@ -2975,6 +2975,24 @@ def test_local_Shape_of_SpecifyShape(shape):
assert
shape
in
fgraph
.
variables
assert
shape
in
fgraph
.
variables
@pytest.mark.parametrize
(
"s1"
,
[
lscalar
(),
iscalar
()],
)
def
test_local_Shape_of_SpecifyShape_partial
(
s1
):
x
=
matrix
()
s
=
specify_shape
(
x
,
(
s1
,
None
))
.
shape
fgraph
=
FunctionGraph
(
outputs
=
[
s
],
clone
=
False
)
assert
any
(
isinstance
(
apply
.
op
,
SpecifyShape
)
for
apply
in
fgraph
.
apply_nodes
)
_
=
optimize_graph
(
fgraph
,
clone
=
False
)
assert
x
in
fgraph
.
variables
assert
s1
in
fgraph
.
variables
assert
not
any
(
isinstance
(
apply
.
op
,
SpecifyShape
)
for
apply
in
fgraph
.
apply_nodes
)
def
test_local_Shape_i_of_broadcastable
():
def
test_local_Shape_i_of_broadcastable
():
x
=
tensor
(
np
.
float64
,
[
False
,
True
])
x
=
tensor
(
np
.
float64
,
[
False
,
True
])
s
=
Shape_i
(
1
)(
x
)
s
=
Shape_i
(
1
)(
x
)
...
...
tests/tensor/test_shape.py
浏览文件 @
181d5566
...
@@ -344,13 +344,13 @@ class TestSpecifyShape(utt.InferShapeTester):
...
@@ -344,13 +344,13 @@ class TestSpecifyShape(utt.InferShapeTester):
specify_shape
([[
1
,
2
,
3
],
[
4
,
5
,
6
]],
(
2.2
,
3
))
specify_shape
([[
1
,
2
,
3
],
[
4
,
5
,
6
]],
(
2.2
,
3
))
with
pytest
.
raises
(
TypeError
,
match
=
"must be integer types"
):
with
pytest
.
raises
(
TypeError
,
match
=
"must be integer types"
):
_specify_shape
([[
1
,
2
,
3
],
[
4
,
5
,
6
]],
(
2.2
,
3
))
_specify_shape
([[
1
,
2
,
3
],
[
4
,
5
,
6
]],
*
(
2.2
,
3
))
with
pytest
.
raises
(
ValueError
,
match
=
"will never match"
):
with
pytest
.
raises
(
ValueError
,
match
=
"will never match"
):
specify_shape
(
matrix
(),
[
4
])
specify_shape
(
matrix
(),
[
4
])
with
pytest
.
raises
(
ValueError
,
match
=
"will never match"
):
with
pytest
.
raises
(
ValueError
,
match
=
"will never match"
):
_specify_shape
(
matrix
(),
[
4
])
_specify_shape
(
matrix
(),
*
[
4
])
with
pytest
.
raises
(
ValueError
,
match
=
"must have fixed dimensions"
):
with
pytest
.
raises
(
ValueError
,
match
=
"must have fixed dimensions"
):
specify_shape
(
matrix
(),
vector
(
dtype
=
"int32"
))
specify_shape
(
matrix
(),
vector
(
dtype
=
"int32"
))
...
@@ -378,6 +378,14 @@ class TestSpecifyShape(utt.InferShapeTester):
...
@@ -378,6 +378,14 @@ class TestSpecifyShape(utt.InferShapeTester):
f
=
aesara
.
function
([
x
],
y
,
mode
=
self
.
mode
)
f
=
aesara
.
function
([
x
],
y
,
mode
=
self
.
mode
)
assert
f
([
15
])
==
[
15
]
assert
f
([
15
])
==
[
15
]
def
test_partial_shapes
(
self
):
x
=
matrix
()
s1
=
lscalar
()
y
=
specify_shape
(
x
,
(
s1
,
None
))
f
=
aesara
.
function
([
x
,
s1
],
y
,
mode
=
self
.
mode
)
assert
f
(
np
.
zeros
((
2
,
5
),
dtype
=
config
.
floatX
),
2
)
.
shape
==
(
2
,
5
)
assert
f
(
np
.
zeros
((
3
,
5
),
dtype
=
config
.
floatX
),
3
)
.
shape
==
(
3
,
5
)
def
test_fixed_shapes
(
self
):
def
test_fixed_shapes
(
self
):
x
=
vector
()
x
=
vector
()
shape
=
as_tensor_variable
([
2
])
shape
=
as_tensor_variable
([
2
])
...
@@ -385,6 +393,15 @@ class TestSpecifyShape(utt.InferShapeTester):
...
@@ -385,6 +393,15 @@ class TestSpecifyShape(utt.InferShapeTester):
assert
y
.
type
.
shape
==
(
2
,)
assert
y
.
type
.
shape
==
(
2
,)
assert
y
.
shape
.
equals
(
shape
)
assert
y
.
shape
.
equals
(
shape
)
def
test_fixed_partial_shapes
(
self
):
x
=
TensorType
(
"floatX"
,
(
None
,
None
))(
"x"
)
y
=
specify_shape
(
x
,
(
None
,
5
))
assert
y
.
type
.
shape
==
(
None
,
5
)
x
=
TensorType
(
"floatX"
,
(
3
,
None
))(
"x"
)
y
=
specify_shape
(
x
,
(
None
,
5
))
assert
y
.
type
.
shape
==
(
3
,
5
)
def
test_python_perform
(
self
):
def
test_python_perform
(
self
):
"""Test the Python `Op.perform` implementation."""
"""Test the Python `Op.perform` implementation."""
x
=
scalar
()
x
=
scalar
()
...
@@ -403,13 +420,20 @@ class TestSpecifyShape(utt.InferShapeTester):
...
@@ -403,13 +420,20 @@ class TestSpecifyShape(utt.InferShapeTester):
with
pytest
.
raises
(
AssertionError
,
match
=
"SpecifyShape:.*"
):
with
pytest
.
raises
(
AssertionError
,
match
=
"SpecifyShape:.*"
):
assert
f
([
1
],
(
2
,))
==
[
1
]
assert
f
([
1
],
(
2
,))
==
[
1
]
x
=
matrix
()
y
=
specify_shape
(
x
,
(
None
,
2
))
f
=
aesara
.
function
([
x
],
y
,
mode
=
Mode
(
"py"
))
assert
f
(
np
.
zeros
((
3
,
2
),
dtype
=
config
.
floatX
))
.
shape
==
(
3
,
2
)
with
pytest
.
raises
(
AssertionError
,
match
=
"SpecifyShape:.*"
):
assert
f
(
np
.
zeros
((
3
,
3
),
dtype
=
config
.
floatX
))
def
test_bad_shape
(
self
):
def
test_bad_shape
(
self
):
"""Test that at run-time we raise an exception when the shape is not the one specified."""
"""Test that at run-time we raise an exception when the shape is not the one specified."""
specify_shape
=
SpecifyShape
()
specify_shape
=
SpecifyShape
()
x
=
vector
()
x
=
vector
()
xval
=
np
.
random
.
random
((
2
))
.
astype
(
config
.
floatX
)
xval
=
np
.
random
.
random
((
2
))
.
astype
(
config
.
floatX
)
f
=
aesara
.
function
([
x
],
specify_shape
(
x
,
[
2
]
),
mode
=
self
.
mode
)
f
=
aesara
.
function
([
x
],
specify_shape
(
x
,
2
),
mode
=
self
.
mode
)
assert
np
.
array_equal
(
f
(
xval
),
xval
)
assert
np
.
array_equal
(
f
(
xval
),
xval
)
...
@@ -426,7 +450,7 @@ class TestSpecifyShape(utt.InferShapeTester):
...
@@ -426,7 +450,7 @@ class TestSpecifyShape(utt.InferShapeTester):
x
=
matrix
()
x
=
matrix
()
xval
=
np
.
random
.
random
((
2
,
3
))
.
astype
(
config
.
floatX
)
xval
=
np
.
random
.
random
((
2
,
3
))
.
astype
(
config
.
floatX
)
f
=
aesara
.
function
([
x
],
specify_shape
(
x
,
[
2
,
3
]
),
mode
=
self
.
mode
)
f
=
aesara
.
function
([
x
],
specify_shape
(
x
,
2
,
3
),
mode
=
self
.
mode
)
assert
isinstance
(
assert
isinstance
(
[
n
for
n
in
f
.
maker
.
fgraph
.
toposort
()
if
isinstance
(
n
.
op
,
SpecifyShape
)][
0
]
[
n
for
n
in
f
.
maker
.
fgraph
.
toposort
()
if
isinstance
(
n
.
op
,
SpecifyShape
)][
0
]
.
inputs
[
0
]
.
inputs
[
0
]
...
@@ -441,6 +465,13 @@ class TestSpecifyShape(utt.InferShapeTester):
...
@@ -441,6 +465,13 @@ class TestSpecifyShape(utt.InferShapeTester):
with
pytest
.
raises
(
AssertionError
,
match
=
"SpecifyShape:.*"
):
with
pytest
.
raises
(
AssertionError
,
match
=
"SpecifyShape:.*"
):
f
(
xval
)
f
(
xval
)
s
=
iscalar
(
"s"
)
f
=
aesara
.
function
([
x
,
s
],
specify_shape
(
x
,
None
,
s
),
mode
=
self
.
mode
)
x_val
=
np
.
zeros
((
3
,
2
),
dtype
=
config
.
floatX
)
assert
f
(
x_val
,
2
)
.
shape
==
(
3
,
2
)
with
pytest
.
raises
(
AssertionError
,
match
=
"SpecifyShape:.*"
):
f
(
xval
,
3
)
def
test_infer_shape
(
self
):
def
test_infer_shape
(
self
):
rng
=
np
.
random
.
default_rng
(
3453
)
rng
=
np
.
random
.
default_rng
(
3453
)
adtens4
=
dtensor4
()
adtens4
=
dtensor4
()
...
@@ -454,6 +485,19 @@ class TestSpecifyShape(utt.InferShapeTester):
...
@@ -454,6 +485,19 @@ class TestSpecifyShape(utt.InferShapeTester):
SpecifyShape
,
SpecifyShape
,
)
)
def
test_infer_shape_partial
(
self
):
rng
=
np
.
random
.
default_rng
(
3453
)
adtens4
=
dtensor4
()
aivec
=
[
iscalar
(),
iscalar
(),
None
,
iscalar
()]
aivec_val
=
[
3
,
4
,
5
]
adtens4_val
=
rng
.
random
((
3
,
4
,
2
,
5
))
self
.
_compile_and_check
(
[
adtens4
,
*
(
ivec
for
ivec
in
aivec
if
ivec
is
not
None
)],
[
specify_shape
(
adtens4
,
aivec
)],
[
adtens4_val
,
*
aivec_val
],
SpecifyShape
,
)
class
TestRopLop
(
RopLopChecker
):
class
TestRopLop
(
RopLopChecker
):
def
test_shape
(
self
):
def
test_shape
(
self
):
...
...
tests/tensor/test_sharedvar.py
浏览文件 @
181d5566
...
@@ -474,7 +474,7 @@ def makeSharedTester(
...
@@ -474,7 +474,7 @@ def makeSharedTester(
assert
np
.
all
(
self
.
ref_fct
(
specify_shape_fct
())
==
self
.
ref_fct
(
x1_2
))
assert
np
.
all
(
self
.
ref_fct
(
specify_shape_fct
())
==
self
.
ref_fct
(
x1_2
))
topo_specify
=
specify_shape_fct
.
maker
.
fgraph
.
toposort
()
topo_specify
=
specify_shape_fct
.
maker
.
fgraph
.
toposort
()
if
aesara
.
config
.
mode
!=
"FAST_COMPILE"
:
if
aesara
.
config
.
mode
!=
"FAST_COMPILE"
:
assert
len
(
topo_specify
)
==
4
assert
len
(
topo_specify
)
==
3
# Test that we put the shape info into the graph
# Test that we put the shape info into the graph
shape_constant_fct
=
aesara
.
function
([],
x1_specify_shape
.
shape
)
shape_constant_fct
=
aesara
.
function
([],
x1_specify_shape
.
shape
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论