Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
7fa753b5
提交
7fa753b5
authored
3月 05, 2026
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
3月 06, 2026
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Improve static shape of shaped xtensor Ops
Also avoid useless cast to XTensorType when 0d-tensor suffices
上级
d38dc060
显示空白字符变更
内嵌
并排
正在显示
8 个修改的文件
包含
62 行增加
和
28 行删除
+62
-28
basic.py
pytensor/tensor/basic.py
+22
-18
subtensor.py
pytensor/tensor/subtensor.py
+2
-2
shape.py
pytensor/xtensor/rewriting/shape.py
+1
-2
shape.py
pytensor/xtensor/shape.py
+5
-2
vectorization.py
pytensor/xtensor/vectorization.py
+2
-1
test_basic.py
tests/tensor/test_basic.py
+1
-2
test_random.py
tests/xtensor/test_random.py
+8
-0
test_shape.py
tests/xtensor/test_shape.py
+21
-1
没有找到文件。
pytensor/tensor/basic.py
浏览文件 @
7fa753b5
...
@@ -29,13 +29,13 @@ from pytensor.graph.fg import FunctionGraph, Output
...
@@ -29,13 +29,13 @@ from pytensor.graph.fg import FunctionGraph, Output
from
pytensor.graph.op
import
Op
from
pytensor.graph.op
import
Op
from
pytensor.graph.replace
import
_vectorize_node
from
pytensor.graph.replace
import
_vectorize_node
from
pytensor.graph.rewriting.db
import
EquilibriumDB
from
pytensor.graph.rewriting.db
import
EquilibriumDB
from
pytensor.graph.type
import
HasShape
,
Type
from
pytensor.graph.type
import
Has
DataType
,
Has
Shape
,
Type
from
pytensor.link.c.op
import
COp
from
pytensor.link.c.op
import
COp
from
pytensor.link.c.params_type
import
ParamsType
from
pytensor.link.c.params_type
import
ParamsType
from
pytensor.printing
import
Printer
,
min_informative_str
,
pprint
,
set_precedence
from
pytensor.printing
import
Printer
,
min_informative_str
,
pprint
,
set_precedence
from
pytensor.raise_op
import
CheckAndRaise
from
pytensor.raise_op
import
CheckAndRaise
from
pytensor.scalar
import
int32
from
pytensor.scalar
import
int32
from
pytensor.scalar.basic
import
ScalarConstant
,
Scalar
Type
,
Scalar
Variable
from
pytensor.scalar.basic
import
ScalarConstant
,
ScalarVariable
from
pytensor.tensor
import
(
from
pytensor.tensor
import
(
_as_tensor_variable
,
_as_tensor_variable
,
_get_vector_length
,
_get_vector_length
,
...
@@ -292,13 +292,8 @@ def _get_underlying_scalar_constant_value(
...
@@ -292,13 +292,8 @@ def _get_underlying_scalar_constant_value(
max_recur : int
max_recur : int
The maximum number of recursion.
The maximum number of recursion.
Notes
-----
There may be another function similar to this one in the code,
but I'm not sure where it is.
"""
"""
from
pytensor.compile.ops
import
DeepCopyOp
,
OutputGuard
from
pytensor.compile.ops
import
DeepCopyOp
,
OutputGuard
,
TypeCastingOp
from
pytensor.sparse
import
CSM
from
pytensor.sparse
import
CSM
from
pytensor.tensor.subtensor
import
Subtensor
from
pytensor.tensor.subtensor
import
Subtensor
...
@@ -319,13 +314,20 @@ def _get_underlying_scalar_constant_value(
...
@@ -319,13 +314,20 @@ def _get_underlying_scalar_constant_value(
raise
NotScalarConstantError
()
raise
NotScalarConstantError
()
if
isinstance
(
v
,
Constant
):
if
isinstance
(
v
,
Constant
):
if
isinstance
(
v
.
type
,
TensorType
)
and
v
.
unique_value
is
not
None
:
v_type
=
v
.
type
return
v
.
unique_value
if
isinstance
(
v_type
,
HasShape
)
and
isinstance
(
v_type
,
HasDataType
):
if
v_type
.
ndim
==
0
:
return
np
.
array
(
v
.
data
,
dtype
=
v
.
type
.
dtype
)
elif
isinstance
(
v
.
type
,
ScalarType
):
elif
(
not
any
(
s
is
None
for
s
in
v_type
.
shape
))
and
(
return
v
.
data
np
.
prod
(
v_type
.
shape
)
==
1
):
return
np
.
array
(
v
.
data
,
dtype
=
v_type
.
dtype
)
.
squeeze
()
elif
isinstance
(
v
.
type
,
NoneTypeT
):
elif
isinstance
(
v_type
,
TensorType
)
and
v
.
unique_value
is
not
None
:
return
np
.
array
(
v
.
unique_value
)
elif
isinstance
(
v_type
,
NoneTypeT
):
return
None
return
None
raise
NotScalarConstantError
()
raise
NotScalarConstantError
()
...
@@ -333,9 +335,9 @@ def _get_underlying_scalar_constant_value(
...
@@ -333,9 +335,9 @@ def _get_underlying_scalar_constant_value(
if
not
only_process_constants
and
getattr
(
v
,
"owner"
,
None
)
and
max_recur
>
0
:
if
not
only_process_constants
and
getattr
(
v
,
"owner"
,
None
)
and
max_recur
>
0
:
op
=
v
.
owner
.
op
op
=
v
.
owner
.
op
max_recur
-=
1
max_recur
-=
1
if
isinstance
(
op
,
Alloc
|
DimShuffle
|
OutputGuard
|
DeepCopyOp
):
if
isinstance
(
# OutputGuard is only used in debugmode but we
op
,
Alloc
|
DimShuffle
|
TypeCastingOp
|
DeepCopyOp
|
OutputGuard
# keep it here to avoid problems with old pickles
):
v
=
v
.
owner
.
inputs
[
0
]
v
=
v
.
owner
.
inputs
[
0
]
continue
continue
elif
isinstance
(
op
,
Shape_i
):
elif
isinstance
(
op
,
Shape_i
):
...
@@ -343,7 +345,6 @@ def _get_underlying_scalar_constant_value(
...
@@ -343,7 +345,6 @@ def _get_underlying_scalar_constant_value(
inp
=
v
.
owner
.
inputs
[
0
]
inp
=
v
.
owner
.
inputs
[
0
]
if
isinstance
(
inp
,
Constant
):
if
isinstance
(
inp
,
Constant
):
return
np
.
asarray
(
np
.
shape
(
inp
.
data
)[
i
])
return
np
.
asarray
(
np
.
shape
(
inp
.
data
)[
i
])
# The shape of a broadcastable dimension is 1
if
isinstance
(
inp
.
type
,
HasShape
)
and
inp
.
type
.
shape
[
i
]
is
not
None
:
if
isinstance
(
inp
.
type
,
HasShape
)
and
inp
.
type
.
shape
[
i
]
is
not
None
:
return
np
.
asarray
(
inp
.
type
.
shape
[
i
])
return
np
.
asarray
(
inp
.
type
.
shape
[
i
])
...
@@ -600,7 +601,10 @@ def get_scalar_constant_value(
...
@@ -600,7 +601,10 @@ def get_scalar_constant_value(
If 'v' is not a scalar, it raises a NotScalarConstantError.
If 'v' is not a scalar, it raises a NotScalarConstantError.
"""
"""
if
isinstance
(
v
,
TensorVariable
|
np
.
ndarray
):
if
isinstance
(
v
,
Variable
)
and
isinstance
(
v
.
type
,
HasShape
):
if
v
.
type
.
ndim
!=
0
:
raise
NotScalarConstantError
(
"Input ndim != 0"
)
elif
isinstance
(
v
,
np
.
ndarray
):
if
v
.
ndim
!=
0
:
if
v
.
ndim
!=
0
:
raise
NotScalarConstantError
(
"Input ndim != 0"
)
raise
NotScalarConstantError
(
"Input ndim != 0"
)
return
get_underlying_scalar_constant_value
(
return
get_underlying_scalar_constant_value
(
...
...
pytensor/tensor/subtensor.py
浏览文件 @
7fa753b5
...
@@ -805,7 +805,7 @@ def get_constant_idx(
...
@@ -805,7 +805,7 @@ def get_constant_idx(
>>> b.owner.op.idx_list
>>> b.owner.op.idx_list
(ScalarType(int64), slice(ScalarType(int64), ScalarType(int64), None))
(ScalarType(int64), slice(ScalarType(int64), ScalarType(int64), None))
>>> get_constant_idx(b.owner.op.idx_list, b.owner.inputs, allow_partial=True)
>>> get_constant_idx(b.owner.op.idx_list, b.owner.inputs, allow_partial=True)
[v, slice(
np.int64(1), np.int64(3)
, None)]
[v, slice(
1, 3
, None)]
>>> get_constant_idx(b.owner.op.idx_list, b.owner.inputs)
>>> get_constant_idx(b.owner.op.idx_list, b.owner.inputs)
Traceback (most recent call last):
Traceback (most recent call last):
pytensor.tensor.exceptions.NotScalarConstantError
pytensor.tensor.exceptions.NotScalarConstantError
...
@@ -825,7 +825,7 @@ def get_constant_idx(
...
@@ -825,7 +825,7 @@ def get_constant_idx(
val
,
val
,
only_process_constants
=
only_process_constants
,
only_process_constants
=
only_process_constants
,
elemwise
=
elemwise
,
elemwise
=
elemwise
,
)
)
.
item
()
except
NotScalarConstantError
:
except
NotScalarConstantError
:
if
allow_partial
:
if
allow_partial
:
return
val
return
val
...
...
pytensor/xtensor/rewriting/shape.py
浏览文件 @
7fa753b5
...
@@ -119,7 +119,6 @@ def lower_expand_dims(fgraph, node):
...
@@ -119,7 +119,6 @@ def lower_expand_dims(fgraph, node):
# Convert inputs to tensors
# Convert inputs to tensors
x_tensor
=
tensor_from_xtensor
(
x
)
x_tensor
=
tensor_from_xtensor
(
x
)
size_tensor
=
tensor_from_xtensor
(
size
)
# Get the new dimension name and position
# Get the new dimension name and position
new_axis
=
0
# Always insert at front
new_axis
=
0
# Always insert at front
...
@@ -130,7 +129,7 @@ def lower_expand_dims(fgraph, node):
...
@@ -130,7 +129,7 @@ def lower_expand_dims(fgraph, node):
result_tensor
=
expand_dims
(
x_tensor
,
new_axis
)
result_tensor
=
expand_dims
(
x_tensor
,
new_axis
)
else
:
else
:
# Otherwise broadcast to the requested size
# Otherwise broadcast to the requested size
result_tensor
=
broadcast_to
(
x_tensor
,
(
size
_tensor
,
*
x_tensor
.
shape
))
result_tensor
=
broadcast_to
(
x_tensor
,
(
size
,
*
x_tensor
.
shape
))
# Preserve static shape information
# Preserve static shape information
result_tensor
=
specify_shape
(
result_tensor
,
out
.
type
.
shape
)
result_tensor
=
specify_shape
(
result_tensor
,
out
.
type
.
shape
)
...
...
pytensor/xtensor/shape.py
浏览文件 @
7fa753b5
...
@@ -123,7 +123,10 @@ class UnStack(XOp):
...
@@ -123,7 +123,10 @@ class UnStack(XOp):
raise
ValueError
(
raise
ValueError
(
f
"Number of unstacked lengths {len(unstacked_length)} must match number of unstacked dims {len(self.unstacked_dims)}"
f
"Number of unstacked lengths {len(unstacked_length)} must match number of unstacked dims {len(self.unstacked_dims)}"
)
)
unstacked_lengths
=
[
as_tensor
(
length
,
ndim
=
0
)
for
length
in
unstacked_length
]
unstacked_lengths
=
[
as_tensor
(
length
,
allow_xtensor_conversion
=
True
)
for
length
in
unstacked_length
]
if
not
all
(
length
.
dtype
in
discrete_dtypes
for
length
in
unstacked_lengths
):
if
not
all
(
length
.
dtype
in
discrete_dtypes
for
length
in
unstacked_lengths
):
raise
TypeError
(
"Unstacked lengths must be discrete dtypes."
)
raise
TypeError
(
"Unstacked lengths must be discrete dtypes."
)
...
@@ -441,7 +444,7 @@ class ExpandDims(XOp):
...
@@ -441,7 +444,7 @@ class ExpandDims(XOp):
if
self
.
dim
in
x
.
type
.
dims
:
if
self
.
dim
in
x
.
type
.
dims
:
raise
ValueError
(
f
"Dimension {self.dim} already exists in {x.type.dims}"
)
raise
ValueError
(
f
"Dimension {self.dim} already exists in {x.type.dims}"
)
size
=
as_
xtensor
(
size
,
dims
=
()
)
size
=
as_
tensor
(
size
,
allow_xtensor_conversion
=
True
)
if
not
(
size
.
dtype
in
integer_dtypes
and
size
.
ndim
==
0
):
if
not
(
size
.
dtype
in
integer_dtypes
and
size
.
ndim
==
0
):
raise
ValueError
(
f
"size should be an integer scalar, got {size.type}"
)
raise
ValueError
(
f
"size should be an integer scalar, got {size.type}"
)
try
:
try
:
...
...
pytensor/xtensor/vectorization.py
浏览文件 @
7fa753b5
...
@@ -16,6 +16,7 @@ from pytensor.graph.type import HasShape
...
@@ -16,6 +16,7 @@ from pytensor.graph.type import HasShape
from
pytensor.scalar
import
discrete_dtypes
from
pytensor.scalar
import
discrete_dtypes
from
pytensor.tensor
import
(
from
pytensor.tensor
import
(
TensorVariable
,
TensorVariable
,
as_tensor
,
broadcast_shape
,
broadcast_shape
,
broadcast_to
,
broadcast_to
,
tensor
,
tensor
,
...
@@ -232,7 +233,7 @@ class XRV(XOp, RNGConsumerOp):
...
@@ -232,7 +233,7 @@ class XRV(XOp, RNGConsumerOp):
)
)
extra_dim_lengths
=
[
extra_dim_lengths
=
[
as_
xtensor
(
dim_length
)
.
values
as_
tensor
(
dim_length
,
allow_xtensor_conversion
=
True
)
for
dim_length
in
extra_dim_lengths_and_params
[:
len
(
self
.
extra_dims
)]
for
dim_length
in
extra_dim_lengths_and_params
[:
len
(
self
.
extra_dims
)]
]
]
if
not
all
(
if
not
all
(
...
...
tests/tensor/test_basic.py
浏览文件 @
7fa753b5
...
@@ -3504,11 +3504,10 @@ class TestGetUnderlyingScalarConstantValue:
...
@@ -3504,11 +3504,10 @@ class TestGetUnderlyingScalarConstantValue:
assert
get_underlying_scalar_constant_value
(
s
)
==
c
.
data
assert
get_underlying_scalar_constant_value
(
s
)
==
c
.
data
def
test_copy
(
self
):
def
test_copy
(
self
):
# Make sure we do not return
a writeabl
e internal storage of a constant,
# Make sure we do not return
th
e internal storage of a constant,
# so we cannot change the value of a constant by mistake.
# so we cannot change the value of a constant by mistake.
c
=
constant
(
3
)
c
=
constant
(
3
)
d
=
get_scalar_constant_value
(
c
)
d
=
get_scalar_constant_value
(
c
)
with
pytest
.
raises
(
ValueError
,
match
=
"output array is read-only"
):
d
+=
1
d
+=
1
e
=
get_scalar_constant_value
(
c
)
e
=
get_scalar_constant_value
(
c
)
assert
e
==
3
,
(
c
,
d
,
e
)
assert
e
==
3
,
(
c
,
d
,
e
)
...
...
tests/xtensor/test_random.py
浏览文件 @
7fa753b5
...
@@ -132,6 +132,14 @@ def test_dtype():
...
@@ -132,6 +132,14 @@ def test_dtype():
assert
x
.
type
.
dtype
==
"float32"
assert
x
.
type
.
dtype
==
"float32"
def
test_static_shape
():
x
=
xtensor
(
"x"
,
dims
=
(
"a"
,
"b"
),
shape
=
(
1
,
None
))
y
=
xtensor
(
"y"
,
dims
=
(
"c"
,
"d"
),
shape
=
(
2
,
None
))
out
=
normal
(
x
,
1
,
extra_dims
=
y
.
sizes
)
assert
out
.
type
.
dims
==
(
"c"
,
"d"
,
"a"
,
"b"
)
assert
out
.
type
.
shape
==
(
2
,
None
,
1
,
None
)
def
test_normal
():
def
test_normal
():
rng
=
random_generator_type
(
"rng"
)
rng
=
random_generator_type
(
"rng"
)
c_size
=
tensor
(
"c_size"
,
shape
=
(),
dtype
=
int
)
c_size
=
tensor
(
"c_size"
,
shape
=
(),
dtype
=
int
)
...
...
tests/xtensor/test_shape.py
浏览文件 @
7fa753b5
...
@@ -25,7 +25,7 @@ from pytensor.xtensor.shape import (
...
@@ -25,7 +25,7 @@ from pytensor.xtensor.shape import (
unstack
,
unstack
,
zeros_like
,
zeros_like
,
)
)
from
pytensor.xtensor.type
import
as_xtensor
,
xtensor
from
pytensor.xtensor.type
import
XTensorType
,
as_xtensor
,
xtensor
from
pytensor.xtensor.vectorization
import
vectorize_graph
from
pytensor.xtensor.vectorization
import
vectorize_graph
from
tests.xtensor.util
import
(
from
tests.xtensor.util
import
(
check_vectorization
,
check_vectorization
,
...
@@ -369,16 +369,22 @@ def test_expand_dims():
...
@@ -369,16 +369,22 @@ def test_expand_dims():
# Implicit size 1
# Implicit size 1
y
=
x
.
expand_dims
(
"country"
)
y
=
x
.
expand_dims
(
"country"
)
assert
y
.
type
.
dims
==
(
"country"
,
"city"
,
"year"
)
assert
y
.
type
.
shape
==
(
1
,
2
,
2
)
fn
=
xr_function
([
x
],
y
)
fn
=
xr_function
([
x
],
y
)
xr_assert_allclose
(
fn
(
x_test
),
x_test
.
expand_dims
(
"country"
))
xr_assert_allclose
(
fn
(
x_test
),
x_test
.
expand_dims
(
"country"
))
# Test with multiple dimensions
# Test with multiple dimensions
y
=
x
.
expand_dims
([
"country"
,
"state"
])
y
=
x
.
expand_dims
([
"country"
,
"state"
])
assert
y
.
type
.
dims
==
(
"country"
,
"state"
,
"city"
,
"year"
)
assert
y
.
type
.
shape
==
(
1
,
1
,
2
,
2
)
fn
=
xr_function
([
x
],
y
)
fn
=
xr_function
([
x
],
y
)
xr_assert_allclose
(
fn
(
x_test
),
x_test
.
expand_dims
([
"country"
,
"state"
]))
xr_assert_allclose
(
fn
(
x_test
),
x_test
.
expand_dims
([
"country"
,
"state"
]))
# Test with a dict of name-size pairs
# Test with a dict of name-size pairs
y
=
x
.
expand_dims
({
"country"
:
2
,
"state"
:
3
})
y
=
x
.
expand_dims
({
"country"
:
2
,
"state"
:
3
})
assert
y
.
type
.
dims
==
(
"country"
,
"state"
,
"city"
,
"year"
)
assert
y
.
type
.
shape
==
(
2
,
3
,
2
,
2
)
fn
=
xr_function
([
x
],
y
)
fn
=
xr_function
([
x
],
y
)
xr_assert_allclose
(
fn
(
x_test
),
x_test
.
expand_dims
({
"country"
:
2
,
"state"
:
3
}))
xr_assert_allclose
(
fn
(
x_test
),
x_test
.
expand_dims
({
"country"
:
2
,
"state"
:
3
}))
...
@@ -390,6 +396,8 @@ def test_expand_dims():
...
@@ -390,6 +396,8 @@ def test_expand_dims():
# Test with a dict of name-coord array pairs
# Test with a dict of name-coord array pairs
with
pytest
.
warns
(
UserWarning
,
match
=
"only its length is used"
):
with
pytest
.
warns
(
UserWarning
,
match
=
"only its length is used"
):
y
=
x
.
expand_dims
({
"country"
:
np
.
array
([
1
,
2
]),
"state"
:
np
.
array
([
3
,
4
,
5
])})
y
=
x
.
expand_dims
({
"country"
:
np
.
array
([
1
,
2
]),
"state"
:
np
.
array
([
3
,
4
,
5
])})
assert
y
.
type
.
dims
==
(
"country"
,
"state"
,
"city"
,
"year"
)
assert
y
.
type
.
shape
==
(
2
,
3
,
2
,
2
)
fn
=
xr_function
([
x
],
y
)
fn
=
xr_function
([
x
],
y
)
xr_assert_allclose
(
xr_assert_allclose
(
fn
(
x_test
),
fn
(
x_test
),
...
@@ -399,12 +407,16 @@ def test_expand_dims():
...
@@ -399,12 +407,16 @@ def test_expand_dims():
# Symbolic size 1
# Symbolic size 1
size_sym_1
=
scalar
(
"size_sym_1"
,
dtype
=
"int64"
)
size_sym_1
=
scalar
(
"size_sym_1"
,
dtype
=
"int64"
)
y
=
x
.
expand_dims
({
"country"
:
size_sym_1
})
y
=
x
.
expand_dims
({
"country"
:
size_sym_1
})
assert
y
.
type
.
dims
==
(
"country"
,
"city"
,
"year"
)
assert
y
.
type
.
shape
==
(
None
,
2
,
2
)
fn
=
xr_function
([
x
,
size_sym_1
],
y
)
fn
=
xr_function
([
x
,
size_sym_1
],
y
)
xr_assert_allclose
(
fn
(
x_test
,
1
),
x_test
.
expand_dims
({
"country"
:
1
}))
xr_assert_allclose
(
fn
(
x_test
,
1
),
x_test
.
expand_dims
({
"country"
:
1
}))
# Test with symbolic sizes in dict
# Test with symbolic sizes in dict
size_sym_2
=
scalar
(
"size_sym_2"
,
dtype
=
"int64"
)
size_sym_2
=
scalar
(
"size_sym_2"
,
dtype
=
"int64"
)
y
=
x
.
expand_dims
({
"country"
:
size_sym_1
,
"state"
:
size_sym_2
})
y
=
x
.
expand_dims
({
"country"
:
size_sym_1
,
"state"
:
size_sym_2
})
assert
y
.
type
.
dims
==
(
"country"
,
"state"
,
"city"
,
"year"
)
assert
y
.
type
.
shape
==
(
None
,
None
,
2
,
2
)
fn
=
xr_function
([
x
,
size_sym_1
,
size_sym_2
],
y
)
fn
=
xr_function
([
x
,
size_sym_1
,
size_sym_2
],
y
)
xr_assert_allclose
(
fn
(
x_test
,
2
,
3
),
x_test
.
expand_dims
({
"country"
:
2
,
"state"
:
3
}))
xr_assert_allclose
(
fn
(
x_test
,
2
,
3
),
x_test
.
expand_dims
({
"country"
:
2
,
"state"
:
3
}))
...
@@ -415,16 +427,24 @@ def test_expand_dims():
...
@@ -415,16 +427,24 @@ def test_expand_dims():
# Test with axis parameter
# Test with axis parameter
y
=
x
.
expand_dims
(
"country"
,
axis
=
1
)
y
=
x
.
expand_dims
(
"country"
,
axis
=
1
)
assert
y
.
type
==
XTensorType
(
dtype
=
x
.
dtype
,
dims
=
(
"city"
,
"country"
,
"year"
),
shape
=
(
2
,
1
,
2
)
)
fn
=
xr_function
([
x
],
y
)
fn
=
xr_function
([
x
],
y
)
xr_assert_allclose
(
fn
(
x_test
),
x_test
.
expand_dims
(
"country"
,
axis
=
1
))
xr_assert_allclose
(
fn
(
x_test
),
x_test
.
expand_dims
(
"country"
,
axis
=
1
))
# Test with negative axis parameter
# Test with negative axis parameter
y
=
x
.
expand_dims
(
"country"
,
axis
=-
1
)
y
=
x
.
expand_dims
(
"country"
,
axis
=-
1
)
assert
y
.
type
==
XTensorType
(
dtype
=
x
.
dtype
,
dims
=
(
"city"
,
"year"
,
"country"
),
shape
=
(
2
,
2
,
1
)
)
fn
=
xr_function
([
x
],
y
)
fn
=
xr_function
([
x
],
y
)
xr_assert_allclose
(
fn
(
x_test
),
x_test
.
expand_dims
(
"country"
,
axis
=-
1
))
xr_assert_allclose
(
fn
(
x_test
),
x_test
.
expand_dims
(
"country"
,
axis
=-
1
))
# Add two new dims with axis parameters
# Add two new dims with axis parameters
y
=
x
.
expand_dims
([
"country"
,
"state"
],
axis
=
[
1
,
2
])
y
=
x
.
expand_dims
([
"country"
,
"state"
],
axis
=
[
1
,
2
])
assert
y
.
type
.
dims
==
(
"city"
,
"country"
,
"state"
,
"year"
)
assert
y
.
type
.
shape
==
(
2
,
1
,
1
,
2
)
fn
=
xr_function
([
x
],
y
)
fn
=
xr_function
([
x
],
y
)
xr_assert_allclose
(
xr_assert_allclose
(
fn
(
x_test
),
x_test
.
expand_dims
([
"country"
,
"state"
],
axis
=
[
1
,
2
])
fn
(
x_test
),
x_test
.
expand_dims
([
"country"
,
"state"
],
axis
=
[
1
,
2
])
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论