Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
b979dd6e
提交
b979dd6e
authored
2月 14, 2022
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
5月 09, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add FunctionGraphs to the HasInnerGraph interface and OpFromGraph, Scan
上级
980ecacf
隐藏空白字符变更
内嵌
并排
正在显示
9 个修改的文件
包含
168 行增加
和
93 行删除
+168
-93
builders.py
aesara/compile/builders.py
+32
-24
gradient.py
aesara/gradient.py
+11
-10
basic.py
aesara/graph/basic.py
+35
-0
op.py
aesara/graph/op.py
+3
-0
printing.py
aesara/printing.py
+8
-8
op.py
aesara/scan/op.py
+27
-4
test_builders.py
tests/compile/test_builders.py
+4
-8
test_basic.py
tests/scan/test_basic.py
+12
-9
test_printing.py
tests/scan/test_printing.py
+36
-30
没有找到文件。
aesara/compile/builders.py
浏览文件 @
b979dd6e
...
@@ -13,10 +13,12 @@ from aesara.gradient import DisconnectedType, Rop, grad
...
@@ -13,10 +13,12 @@ from aesara.gradient import DisconnectedType, Rop, grad
from
aesara.graph.basic
import
(
from
aesara.graph.basic
import
(
Apply
,
Apply
,
Constant
,
Constant
,
NominalVariable
,
Variable
,
Variable
,
clone_replace
,
clone_replace
,
graph_inputs
,
graph_inputs
,
io_connection_pattern
,
io_connection_pattern
,
replace_nominals_with_dummies
,
)
)
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.null_type
import
NullType
from
aesara.graph.null_type
import
NullType
...
@@ -349,17 +351,32 @@ class OpFromGraph(Op, HasInnerGraph):
...
@@ -349,17 +351,32 @@ class OpFromGraph(Op, HasInnerGraph):
raise
NotImplementedError
(
"Updates and givens are not allowed here"
)
raise
NotImplementedError
(
"Updates and givens are not allowed here"
)
self
.
is_inline
=
inline
self
.
is_inline
=
inline
# To correctly support shared variables the inner fct should
# To correctly support shared variables the inner fct should
# not see them. Otherwise there is a problem with the gradient.
# not see them. Otherwise there is a problem with the gradient.
self
.
shared_inputs
=
[
self
.
shared_inputs
=
[]
var
for
var
in
graph_inputs
(
outputs
)
if
isinstance
(
var
,
SharedVariable
)
for
var
in
graph_inputs
(
outputs
):
if
isinstance
(
var
,
SharedVariable
):
self
.
shared_inputs
.
append
(
var
)
inputs
,
outputs
=
replace_nominals_with_dummies
(
inputs
,
outputs
)
# The inputs should be `NominalVariable`s, so that graphs can be merged
replacements
=
{}
for
n
,
v
in
enumerate
(
inputs
):
replacements
[
v
]
=
NominalVariable
(
n
,
v
.
type
)
shared_vars
=
[
NominalVariable
(
n
,
var
.
type
)
for
n
,
var
in
enumerate
(
self
.
shared_inputs
,
start
=
len
(
inputs
)
+
1
)
]
]
shared_vars
=
[
var
.
type
()
for
var
in
self
.
shared_inputs
]
replacements
.
update
(
dict
(
zip
(
self
.
shared_inputs
,
shared_vars
)))
new
=
rebuild_collect_shared
(
new
=
rebuild_collect_shared
(
cast
(
Sequence
[
Variable
],
outputs
),
cast
(
Sequence
[
Variable
],
outputs
),
inputs
=
inputs
+
shared_vars
,
inputs
=
inputs
+
shared_vars
,
replace
=
dict
(
zip
(
self
.
shared_inputs
,
shared_vars
))
,
replace
=
replacements
,
copy_inputs_over
=
False
,
copy_inputs_over
=
False
,
)
)
(
(
...
@@ -374,10 +391,7 @@ class OpFromGraph(Op, HasInnerGraph):
...
@@ -374,10 +391,7 @@ class OpFromGraph(Op, HasInnerGraph):
assert
not
update_expr
assert
not
update_expr
assert
not
shared_inputs
assert
not
shared_inputs
self
.
_inner_inputs
=
local_inputs
self
.
fgraph
=
FunctionGraph
(
local_inputs
,
local_outputs
,
clone
=
False
)
self
.
_inner_outputs
=
local_outputs
self
.
inputs
=
inputs
self
.
outputs
=
outputs
self
.
kwargs
=
kwargs
self
.
kwargs
=
kwargs
self
.
input_types
=
[
inp
.
type
for
inp
in
inputs
]
self
.
input_types
=
[
inp
.
type
for
inp
in
inputs
]
self
.
output_types
=
[
out
.
type
for
out
in
outputs
]
self
.
output_types
=
[
out
.
type
for
out
in
outputs
]
...
@@ -778,29 +792,23 @@ class OpFromGraph(Op, HasInnerGraph):
...
@@ -778,29 +792,23 @@ class OpFromGraph(Op, HasInnerGraph):
# The shared variables are not equal to the original shared
# The shared variables are not equal to the original shared
# variables, so we construct a new `Op` that uses the new shared
# variables, so we construct a new `Op` that uses the new shared
# variables instead.
# variables instead.
# All this is really doing is making the unused (internally, at
replace
=
dict
(
# least) `self.outputs` and `self.shared_inputs` consistent.
zip
(
self
.
inner_inputs
[
num_expected_inps
:],
new_shared_inputs
)
# We could just as easily `copy` this `Op`, update
)
# `self.shared_inputs`, and avoid cloning anything, but this is a
# more "change-proof" approach, because it still work when/if those
# attributes end up being used.
replace
=
dict
(
inner_and_input_shareds
)
# If the new shared variables are inconsistent with the inner-graph,
# If the new shared variables are inconsistent with the inner-graph,
# such errors should arise in this step
# such errors should arise in this step
new_inner_outputs
=
clone_replace
(
new_inner_outputs
=
clone_replace
(
self
.
outputs
,
replace
=
replace
,
share_inputs
=
True
self
.
inner_
outputs
,
replace
=
replace
,
share_inputs
=
True
)
)
# `self.inputs` should not contain any shared variables, so we know
# It's possible that the new shared variable inputs aren't actually
# that those are inputs to `new_outputs`, because we chose not to
# shared variables. When they aren't we need to add them as new
# clone inputs; however, it's possible that the new shared variable
# inputs.
# inputs aren't actually shared variables. When they aren't we
# need to add them as new inputs.
unshared_inputs
=
[
unshared_inputs
=
[
inp
for
inp
in
new_shared_inputs
if
not
isinstance
(
inp
,
SharedVariable
)
inp
for
inp
in
new_shared_inputs
if
not
isinstance
(
inp
,
SharedVariable
)
]
]
new_inner_inputs
=
self
.
in
puts
+
unshared_inputs
new_inner_inputs
=
self
.
in
ner_inputs
[:
num_expected_inps
]
+
unshared_inputs
new_op
=
type
(
self
)(
new_op
=
type
(
self
)(
inputs
=
new_inner_inputs
,
inputs
=
new_inner_inputs
,
...
@@ -901,11 +909,11 @@ class OpFromGraph(Op, HasInnerGraph):
...
@@ -901,11 +909,11 @@ class OpFromGraph(Op, HasInnerGraph):
@property
@property
def
inner_inputs
(
self
):
def
inner_inputs
(
self
):
return
self
.
_inner_
inputs
return
self
.
fgraph
.
inputs
@property
@property
def
inner_outputs
(
self
):
def
inner_outputs
(
self
):
return
self
.
_inner_
outputs
return
self
.
fgraph
.
outputs
def
perform
(
self
,
node
,
inputs
,
outputs
):
def
perform
(
self
,
node
,
inputs
,
outputs
):
variables
=
self
.
fn
(
*
inputs
)
variables
=
self
.
fn
(
*
inputs
)
...
...
aesara/gradient.py
浏览文件 @
b979dd6e
...
@@ -13,7 +13,7 @@ import aesara
...
@@ -13,7 +13,7 @@ import aesara
from
aesara.compile.ops
import
ViewOp
from
aesara.compile.ops
import
ViewOp
from
aesara.configdefaults
import
config
from
aesara.configdefaults
import
config
from
aesara.graph
import
utils
from
aesara.graph
import
utils
from
aesara.graph.basic
import
Variable
from
aesara.graph.basic
import
NominalVariable
,
Variable
from
aesara.graph.null_type
import
NullType
,
null_type
from
aesara.graph.null_type
import
NullType
,
null_type
from
aesara.graph.op
import
get_test_values
from
aesara.graph.op
import
get_test_values
from
aesara.graph.type
import
Type
from
aesara.graph.type
import
Type
...
@@ -1295,15 +1295,16 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None):
...
@@ -1295,15 +1295,16 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None):
# has the right shape
# has the right shape
if
hasattr
(
term
,
"shape"
):
if
hasattr
(
term
,
"shape"
):
orig_ipt
=
inputs
[
i
]
orig_ipt
=
inputs
[
i
]
for
orig_ipt_v
,
term_v
in
get_test_values
(
orig_ipt
,
term
):
if
not
isinstance
(
orig_ipt
,
NominalVariable
):
i_shape
=
orig_ipt_v
.
shape
for
orig_ipt_v
,
term_v
in
get_test_values
(
orig_ipt
,
term
):
t_shape
=
term_v
.
shape
i_shape
=
orig_ipt_v
.
shape
if
i_shape
!=
t_shape
:
t_shape
=
term_v
.
shape
raise
ValueError
(
if
i_shape
!=
t_shape
:
f
"{node.op}.grad returned object of "
raise
ValueError
(
f
"shape {t_shape} as gradient term on input {int(i)} "
f
"{node.op}.grad returned object of "
f
"of shape {i_shape}"
f
"shape {t_shape} as gradient term on input {int(i)} "
)
f
"of shape {i_shape}"
)
if
not
isinstance
(
term
.
type
,
(
NullType
,
DisconnectedType
)):
if
not
isinstance
(
term
.
type
,
(
NullType
,
DisconnectedType
)):
if
term
.
type
.
dtype
not
in
aesara
.
tensor
.
type
.
float_dtypes
:
if
term
.
type
.
dtype
not
in
aesara
.
tensor
.
type
.
float_dtypes
:
...
...
aesara/graph/basic.py
浏览文件 @
b979dd6e
...
@@ -1755,3 +1755,38 @@ def get_var_by_name(
...
@@ -1755,3 +1755,38 @@ def get_var_by_name(
results
+=
(
var
,)
results
+=
(
var
,)
return
results
return
results
def
replace_nominals_with_dummies
(
inputs
,
outputs
):
"""Replace nominal inputs with dummy variables.
When constructing a new graph with nominal inputs from an existing graph,
pre-existing nominal inputs need to be replaced with dummy variables
beforehand; otherwise, sequential ID ordering (i.e. when nominals are IDed
based on the ordered inputs to which they correspond) of the nominals could
be broken, and/or circular replacements could manifest.
FYI: This function assumes that all the nominal variables in the subgraphs
between `inputs` and `outputs` are present in `inputs`.
"""
existing_nominal_replacements
=
{
i
:
i
.
type
()
for
i
in
inputs
if
isinstance
(
i
,
NominalVariable
)
}
if
existing_nominal_replacements
:
# Replace existing nominal variables, because we need to produce an
# inner-graph for which the nominal variable IDs correspond exactly
# to their input order
_
=
clone_get_equiv
(
inputs
,
outputs
,
copy_inputs
=
False
,
copy_orphans
=
False
,
memo
=
existing_nominal_replacements
,
)
outputs
=
[
existing_nominal_replacements
[
o
]
for
o
in
outputs
]
inputs
=
[
existing_nominal_replacements
[
i
]
for
i
in
inputs
]
return
inputs
,
outputs
aesara/graph/op.py
浏览文件 @
b979dd6e
...
@@ -615,6 +615,9 @@ class _NoPythonOp(Op):
...
@@ -615,6 +615,9 @@ class _NoPythonOp(Op):
class
HasInnerGraph
:
class
HasInnerGraph
:
r"""A mixin for an `Op` that contain an inner graph."""
r"""A mixin for an `Op` that contain an inner graph."""
fgraph
:
"FunctionGraph"
"""A `FunctionGraph` of the inner function."""
@property
@property
@abstractmethod
@abstractmethod
def
fn
(
self
)
->
"Function"
:
def
fn
(
self
)
->
"Function"
:
...
...
aesara/printing.py
浏览文件 @
b979dd6e
...
@@ -375,6 +375,7 @@ N.B.:
...
@@ -375,6 +375,7 @@ N.B.:
print_op_info
=
print_op_info
,
print_op_info
=
print_op_info
,
print_destroy_map
=
print_destroy_map
,
print_destroy_map
=
print_destroy_map
,
print_view_map
=
print_view_map
,
print_view_map
=
print_view_map
,
inner_graph_node
=
s
.
owner
,
)
)
if
file
is
_file
:
if
file
is
_file
:
...
@@ -407,6 +408,7 @@ def _debugprint(
...
@@ -407,6 +408,7 @@ def _debugprint(
op_information
:
Optional
[
Dict
[
Apply
,
Dict
[
Variable
,
str
]]]
=
None
,
op_information
:
Optional
[
Dict
[
Apply
,
Dict
[
Variable
,
str
]]]
=
None
,
parent_node
:
Optional
[
Apply
]
=
None
,
parent_node
:
Optional
[
Apply
]
=
None
,
print_op_info
:
bool
=
False
,
print_op_info
:
bool
=
False
,
inner_graph_node
:
Optional
[
Apply
]
=
None
,
)
->
IOBase
:
)
->
IOBase
:
r"""Print the graph leading to `r`.
r"""Print the graph leading to `r`.
...
@@ -459,6 +461,8 @@ def _debugprint(
...
@@ -459,6 +461,8 @@ def _debugprint(
print_op_info
print_op_info
Print extra information provided by the relevant `Op`\s. For example,
Print extra information provided by the relevant `Op`\s. For example,
print the tap information for `Scan` inputs and outputs.
print the tap information for `Scan` inputs and outputs.
inner_graph_node
The inner-graph node in which `r` is contained.
"""
"""
if
depth
==
0
:
if
depth
==
0
:
return
file
return
file
...
@@ -615,6 +619,7 @@ def _debugprint(
...
@@ -615,6 +619,7 @@ def _debugprint(
print_op_info
=
print_op_info
,
print_op_info
=
print_op_info
,
print_destroy_map
=
print_destroy_map
,
print_destroy_map
=
print_destroy_map
,
print_view_map
=
print_view_map
,
print_view_map
=
print_view_map
,
inner_graph_node
=
inner_graph_node
,
)
)
else
:
else
:
...
@@ -644,14 +649,9 @@ def _debugprint(
...
@@ -644,14 +649,9 @@ def _debugprint(
var_output
=
f
"{var_output} -> {outer_id_str}"
var_output
=
f
"{var_output} -> {outer_id_str}"
# This is an inner-graph input, so we need to find the outer node
node_info
=
op_information
.
get
(
inner_graph_node
)
# it belongs to and get the extra information from that
if
node_info
and
r
in
node_info
:
for
inner_graph
in
inner_graph_ops
:
var_output
=
f
"{var_output} ({node_info[r]})"
if
outer_r
in
inner_graph
.
owner
.
inputs
:
node_info
=
op_information
.
get
(
inner_graph
.
owner
)
if
node_info
and
r
in
node_info
:
var_output
=
f
"{var_output} ({node_info[r]})"
break
node_info
=
op_information
.
get
(
parent_node
)
or
op_information
.
get
(
r
.
owner
)
node_info
=
op_information
.
get
(
parent_node
)
or
op_information
.
get
(
r
.
owner
)
if
node_info
and
r
in
node_info
:
if
node_info
and
r
in
node_info
:
...
...
aesara/scan/op.py
浏览文件 @
b979dd6e
...
@@ -54,6 +54,7 @@ import numpy as np
...
@@ -54,6 +54,7 @@ import numpy as np
import
aesara
import
aesara
from
aesara
import
tensor
as
at
from
aesara
import
tensor
as
at
from
aesara.compile
import
SharedVariable
from
aesara.compile.builders
import
infer_shape
from
aesara.compile.builders
import
infer_shape
from
aesara.compile.function
import
function
from
aesara.compile.function
import
function
from
aesara.compile.io
import
In
,
Out
from
aesara.compile.io
import
In
,
Out
...
@@ -64,13 +65,16 @@ from aesara.gradient import DisconnectedType, NullType, Rop, grad, grad_undefine
...
@@ -64,13 +65,16 @@ from aesara.gradient import DisconnectedType, NullType, Rop, grad, grad_undefine
from
aesara.graph.basic
import
(
from
aesara.graph.basic
import
(
Apply
,
Apply
,
Constant
,
Constant
,
NominalVariable
,
Variable
,
Variable
,
clone_replace
,
clone_replace
,
equal_computations
,
equal_computations
,
graph_inputs
,
graph_inputs
,
io_connection_pattern
,
io_connection_pattern
,
replace_nominals_with_dummies
,
)
)
from
aesara.graph.features
import
NoOutputFromInplace
from
aesara.graph.features
import
NoOutputFromInplace
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.op
import
HasInnerGraph
,
Op
from
aesara.graph.op
import
HasInnerGraph
,
Op
from
aesara.graph.utils
import
MissingInputError
from
aesara.graph.utils
import
MissingInputError
from
aesara.link.c.basic
import
CLinker
from
aesara.link.c.basic
import
CLinker
...
@@ -757,8 +761,27 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -757,8 +761,27 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
If ``True``, all the shared variables used in the inner-graph must be provided.
If ``True``, all the shared variables used in the inner-graph must be provided.
"""
"""
self
.
inputs
=
inputs
inputs
,
outputs
=
replace_nominals_with_dummies
(
inputs
,
outputs
)
self
.
outputs
=
outputs
input_replacements
=
[]
for
n
,
v
in
enumerate
(
inputs
):
if
not
isinstance
(
v
,
(
SharedVariable
,
Constant
)):
input_replacements
.
append
((
v
,
NominalVariable
(
n
,
v
.
type
)))
assert
not
isinstance
(
v
,
NominalVariable
)
outputs
=
clone_replace
(
outputs
,
replace
=
input_replacements
)
if
input_replacements
:
_
,
inputs_
=
zip
(
*
input_replacements
)
inputs
=
list
(
inputs_
)
else
:
inputs
=
[]
self
.
fgraph
=
FunctionGraph
(
inputs
,
outputs
,
clone
=
False
)
self
.
inputs
=
self
.
fgraph
.
inputs
self
.
outputs
=
self
.
fgraph
.
outputs
self
.
info
=
info
self
.
info
=
info
self
.
truncate_gradient
=
truncate_gradient
self
.
truncate_gradient
=
truncate_gradient
self
.
name
=
name
self
.
name
=
name
...
@@ -3416,8 +3439,8 @@ def _op_debug_information_Scan(op, node):
...
@@ -3416,8 +3439,8 @@ def _op_debug_information_Scan(op, node):
inner_inputs
=
inner_fn
.
maker
.
fgraph
.
inputs
inner_inputs
=
inner_fn
.
maker
.
fgraph
.
inputs
inner_outputs
=
inner_fn
.
maker
.
fgraph
.
outputs
inner_outputs
=
inner_fn
.
maker
.
fgraph
.
outputs
else
:
else
:
inner_inputs
=
op
.
inputs
inner_inputs
=
op
.
in
ner_in
puts
inner_outputs
=
op
.
outputs
inner_outputs
=
op
.
inner_
outputs
scan_args
=
ScanArgs
(
scan_args
=
ScanArgs
(
node
.
inputs
,
node
.
inputs
,
...
...
tests/compile/test_builders.py
浏览文件 @
b979dd6e
...
@@ -466,7 +466,6 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
...
@@ -466,7 +466,6 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
y
=
shared
(
1.0
,
name
=
"y"
)
y
=
shared
(
1.0
,
name
=
"y"
)
test_ofg
=
OpFromGraph
([
x
],
[
x
+
y
],
on_unused_input
=
"ignore"
)
test_ofg
=
OpFromGraph
([
x
],
[
x
+
y
],
on_unused_input
=
"ignore"
)
assert
test_ofg
.
inputs
==
[
x
]
assert
test_ofg
.
shared_inputs
==
[
y
]
assert
test_ofg
.
shared_inputs
==
[
y
]
out
=
test_ofg
(
x
)
out
=
test_ofg
(
x
)
...
@@ -478,7 +477,6 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
...
@@ -478,7 +477,6 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
out_new
=
test_ofg
.
make_node
(
*
(
out
.
owner
.
inputs
[:
1
]
+
[
y_clone
]))
.
outputs
[
0
]
out_new
=
test_ofg
.
make_node
(
*
(
out
.
owner
.
inputs
[:
1
]
+
[
y_clone
]))
.
outputs
[
0
]
assert
"on_unused_input"
in
out_new
.
owner
.
op
.
kwargs
assert
"on_unused_input"
in
out_new
.
owner
.
op
.
kwargs
assert
out_new
.
owner
.
op
.
inputs
==
[
x
]
assert
out_new
.
owner
.
op
.
shared_inputs
==
[
y_clone
]
assert
out_new
.
owner
.
op
.
shared_inputs
==
[
y_clone
]
out_fn
=
function
([
x
],
out_new
)
out_fn
=
function
([
x
],
out_new
)
...
@@ -497,7 +495,6 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
...
@@ -497,7 +495,6 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
y
=
shared
(
1.0
,
name
=
"y"
)
y
=
shared
(
1.0
,
name
=
"y"
)
test_ofg
=
OpFromGraph
([
x
],
[
x
+
y
])
test_ofg
=
OpFromGraph
([
x
],
[
x
+
y
])
assert
test_ofg
.
inputs
==
[
x
]
assert
test_ofg
.
shared_inputs
==
[
y
]
assert
test_ofg
.
shared_inputs
==
[
y
]
out
=
test_ofg
(
at
.
as_tensor
(
1.0
,
dtype
=
config
.
floatX
))
out
=
test_ofg
(
at
.
as_tensor
(
1.0
,
dtype
=
config
.
floatX
))
...
@@ -517,7 +514,6 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
...
@@ -517,7 +514,6 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
y
=
shared
(
1.0
,
name
=
"y"
)
y
=
shared
(
1.0
,
name
=
"y"
)
test_ofg
=
OpFromGraph
([],
[
y
])
test_ofg
=
OpFromGraph
([],
[
y
])
assert
test_ofg
.
inputs
==
[]
assert
test_ofg
.
shared_inputs
==
[
y
]
assert
test_ofg
.
shared_inputs
==
[
y
]
out_1_fn
=
function
([],
test_ofg
())
out_1_fn
=
function
([],
test_ofg
())
...
@@ -526,7 +522,6 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
...
@@ -526,7 +522,6 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
assert
np
.
array_equal
(
res_1
,
1.0
)
assert
np
.
array_equal
(
res_1
,
1.0
)
test_ofg_new
=
test_ofg
.
make_node
(
x
)
test_ofg_new
=
test_ofg
.
make_node
(
x
)
assert
test_ofg_new
.
op
.
inputs
==
[
x
]
assert
test_ofg_new
.
op
.
shared_inputs
==
[]
assert
test_ofg_new
.
op
.
shared_inputs
==
[]
out_2_fn
=
function
([
x
],
test_ofg_new
.
outputs
[
0
])
out_2_fn
=
function
([
x
],
test_ofg_new
.
outputs
[
0
])
...
@@ -535,6 +530,7 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
...
@@ -535,6 +530,7 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
assert
np
.
array_equal
(
res_2
,
1.0
)
assert
np
.
array_equal
(
res_2
,
1.0
)
@config.change_flags
(
floatX
=
"float64"
)
def
test_debugprint
():
def
test_debugprint
():
x
,
y
,
z
=
matrices
(
"xyz"
)
x
,
y
,
z
=
matrices
(
"xyz"
)
e
=
x
+
y
*
z
e
=
x
+
y
*
z
...
@@ -553,10 +549,10 @@ Inner graphs:
...
@@ -553,10 +549,10 @@ Inner graphs:
OpFromGraph{inline=False} [id A]
OpFromGraph{inline=False} [id A]
>Elemwise{add,no_inplace} [id E]
>Elemwise{add,no_inplace} [id E]
> |
x
[id F]
> |
*0-<TensorType(float64, (None, None))>
[id F]
> |Elemwise{mul,no_inplace} [id G]
> |Elemwise{mul,no_inplace} [id G]
> |
y
[id H]
> |
*1-<TensorType(float64, (None, None))>
[id H]
> |
z
[id I]
> |
*2-<TensorType(float64, (None, None))>
[id I]
"""
"""
for
truth
,
out
in
zip
(
exp_res
.
split
(
"
\n
"
),
lines
):
for
truth
,
out
in
zip
(
exp_res
.
split
(
"
\n
"
),
lines
):
...
...
tests/scan/test_basic.py
浏览文件 @
b979dd6e
...
@@ -2355,9 +2355,11 @@ def test_compute_test_values():
...
@@ -2355,9 +2355,11 @@ def test_compute_test_values():
assert
np
.
array_equal
(
z_grad
.
tag
.
test_value
,
np
.
r_
[
9.0
,
9.0
,
9.0
])
assert
np
.
array_equal
(
z_grad
.
tag
.
test_value
,
np
.
r_
[
9.0
,
9.0
,
9.0
])
@pytest.mark.xfail
(
reason
=
"NominalVariables don't support test values"
)
def
test_compute_test_value_grad
():
def
test_compute_test_value_grad
():
# Test case originally reported by Bitton Tenessi
"""
# https://groups.google.com/d/msg/theano-users/fAP3i2CbskQ/3OgBf4yjqiQJ
See https://groups.google.com/d/msg/theano-users/fAP3i2CbskQ/3OgBf4yjqiQJ
"""
WEIGHT
=
np
.
array
([
1
,
2
,
1
,
3
,
4
,
1
,
5
,
6
,
1
,
7
,
8
,
1
],
dtype
=
"float32"
)
WEIGHT
=
np
.
array
([
1
,
2
,
1
,
3
,
4
,
1
,
5
,
6
,
1
,
7
,
8
,
1
],
dtype
=
"float32"
)
with
config
.
change_flags
(
compute_test_value
=
"raise"
,
exception_verbosity
=
"high"
):
with
config
.
change_flags
(
compute_test_value
=
"raise"
,
exception_verbosity
=
"high"
):
...
@@ -2395,10 +2397,12 @@ def test_compute_test_value_grad():
...
@@ -2395,10 +2397,12 @@ def test_compute_test_value_grad():
grad
(
loss
,
W_flat
)
grad
(
loss
,
W_flat
)
@pytest.mark.xfail
(
reason
=
"NominalVariables don't support test values"
)
def
test_compute_test_value_grad_cast
():
def
test_compute_test_value_grad_cast
():
# Test for test values when variables have to be casted
"""Test for test values when variables have to be casted.
# Reported by Daniel Renshaw at
# https://groups.google.com/d/topic/theano-users/o4jK9xDe5WI/discussion
See https://groups.google.com/d/topic/theano-users/o4jK9xDe5WI/discussion
"""
with
config
.
change_flags
(
compute_test_value
=
"raise"
):
with
config
.
change_flags
(
compute_test_value
=
"raise"
):
h
=
matrix
(
"h"
)
h
=
matrix
(
"h"
)
h
.
tag
.
test_value
=
np
.
array
([[
1
,
2
,
3
,
4
],
[
5
,
6
,
7
,
8
]],
dtype
=
config
.
floatX
)
h
.
tag
.
test_value
=
np
.
array
([[
1
,
2
,
3
,
4
],
[
5
,
6
,
7
,
8
]],
dtype
=
config
.
floatX
)
...
@@ -2434,7 +2438,7 @@ def test_constant_folding_n_steps():
...
@@ -2434,7 +2438,7 @@ def test_constant_folding_n_steps():
def
test_outputs_taps_check
():
def
test_outputs_taps_check
():
# Checks that errors are raised with bad output_info taps.
"""Checks that errors are raised with bad output_info taps."""
x
=
fvector
(
"x"
)
x
=
fvector
(
"x"
)
y
=
fvector
(
"y"
)
y
=
fvector
(
"y"
)
...
@@ -2462,7 +2466,6 @@ def test_inconsistent_broadcast_error():
...
@@ -2462,7 +2466,6 @@ def test_inconsistent_broadcast_error():
grad
(
y
.
sum
(),
x
)
grad
(
y
.
sum
(),
x
)
@pytest.mark.xfail
(
raises
=
MissingInputError
)
def
test_missing_input_error
():
def
test_missing_input_error
():
c
=
shared
(
0.0
)
c
=
shared
(
0.0
)
inc
=
scalar
(
"inc"
)
inc
=
scalar
(
"inc"
)
...
@@ -2470,8 +2473,8 @@ def test_missing_input_error():
...
@@ -2470,8 +2473,8 @@ def test_missing_input_error():
def
count_up
():
def
count_up
():
return
at
.
zeros
(()),
{
c
:
c
+
inc
}
return
at
.
zeros
(()),
{
c
:
c
+
inc
}
_
,
updates
=
scan
(
count_up
,
n_steps
=
20
)
with
pytest
.
raises
(
MissingInputError
):
function
(
inputs
=
[
inc
],
outputs
=
[],
updates
=
updates
)
_
,
updates
=
scan
(
count_up
,
n_steps
=
20
)
class
TestGradUntil
:
class
TestGradUntil
:
...
...
tests/scan/test_printing.py
浏览文件 @
b979dd6e
...
@@ -3,10 +3,12 @@ import pytest
...
@@ -3,10 +3,12 @@ import pytest
import
aesara
import
aesara
import
aesara.tensor
as
at
import
aesara.tensor
as
at
from
aesara.configdefaults
import
config
from
aesara.printing
import
debugprint
,
pydot_imported
,
pydotprint
from
aesara.printing
import
debugprint
,
pydot_imported
,
pydotprint
from
aesara.tensor.type
import
dvector
,
iscalar
,
scalar
,
vector
from
aesara.tensor.type
import
dvector
,
iscalar
,
scalar
,
vector
@config.change_flags
(
floatX
=
"float64"
)
def
test_debugprint_sitsot
():
def
test_debugprint_sitsot
():
k
=
iscalar
(
"k"
)
k
=
iscalar
(
"k"
)
A
=
dvector
(
"A"
)
A
=
dvector
(
"A"
)
...
@@ -55,8 +57,8 @@ def test_debugprint_sitsot():
...
@@ -55,8 +57,8 @@ def test_debugprint_sitsot():
for{cpu,scan_fn} [id C] (outer_out_sit_sot-0)
for{cpu,scan_fn} [id C] (outer_out_sit_sot-0)
>Elemwise{mul,no_inplace} [id W] (inner_out_sit_sot-0)
>Elemwise{mul,no_inplace} [id W] (inner_out_sit_sot-0)
> |<TensorType(float64, (None,))> [id X] -> [id E] (inner_in_sit_sot-0)
> |
*0-
<TensorType(float64, (None,))> [id X] -> [id E] (inner_in_sit_sot-0)
> |
A_copy
[id Y] -> [id M] (inner_in_non_seqs-0)"""
> |
*1-<TensorType(float64, (None,))>
[id Y] -> [id M] (inner_in_non_seqs-0)"""
for
truth
,
out
in
zip
(
expected_output
.
split
(
"
\n
"
),
lines
):
for
truth
,
out
in
zip
(
expected_output
.
split
(
"
\n
"
),
lines
):
assert
truth
.
strip
()
==
out
.
strip
()
assert
truth
.
strip
()
==
out
.
strip
()
...
@@ -110,13 +112,14 @@ def test_debugprint_sitsot_no_extra_info():
...
@@ -110,13 +112,14 @@ def test_debugprint_sitsot_no_extra_info():
for{cpu,scan_fn} [id C]
for{cpu,scan_fn} [id C]
>Elemwise{mul,no_inplace} [id W]
>Elemwise{mul,no_inplace} [id W]
> |<TensorType(float64, (None,))> [id X] -> [id E]
> |
*0-
<TensorType(float64, (None,))> [id X] -> [id E]
> |
A_copy
[id Y] -> [id M]"""
> |
*1-<TensorType(float64, (None,))>
[id Y] -> [id M]"""
for
truth
,
out
in
zip
(
expected_output
.
split
(
"
\n
"
),
lines
):
for
truth
,
out
in
zip
(
expected_output
.
split
(
"
\n
"
),
lines
):
assert
truth
.
strip
()
==
out
.
strip
()
assert
truth
.
strip
()
==
out
.
strip
()
@config.change_flags
(
floatX
=
"float64"
)
def
test_debugprint_nitsot
():
def
test_debugprint_nitsot
():
coefficients
=
vector
(
"coefficients"
)
coefficients
=
vector
(
"coefficients"
)
x
=
scalar
(
"x"
)
x
=
scalar
(
"x"
)
...
@@ -170,15 +173,16 @@ def test_debugprint_nitsot():
...
@@ -170,15 +173,16 @@ def test_debugprint_nitsot():
for{cpu,scan_fn} [id B] (outer_out_nit_sot-0)
for{cpu,scan_fn} [id B] (outer_out_nit_sot-0)
>Elemwise{mul,no_inplace} [id X] (inner_out_nit_sot-0)
>Elemwise{mul,no_inplace} [id X] (inner_out_nit_sot-0)
> |
coefficients[t]
[id Y] -> [id S] (inner_in_seqs-0)
> |
*0-<TensorType(float64, ())>
[id Y] -> [id S] (inner_in_seqs-0)
> |Elemwise{pow,no_inplace} [id Z]
> |Elemwise{pow,no_inplace} [id Z]
> |
x_copy
[id BA] -> [id W] (inner_in_non_seqs-0)
> |
*2-<TensorType(float64, ())>
[id BA] -> [id W] (inner_in_non_seqs-0)
> |<TensorType(int64, ())> [id BB] -> [id U] (inner_in_seqs-1)"""
> |
*1-
<TensorType(int64, ())> [id BB] -> [id U] (inner_in_seqs-1)"""
for
truth
,
out
in
zip
(
expected_output
.
split
(
"
\n
"
),
lines
):
for
truth
,
out
in
zip
(
expected_output
.
split
(
"
\n
"
),
lines
):
assert
truth
.
strip
()
==
out
.
strip
()
assert
truth
.
strip
()
==
out
.
strip
()
@config.change_flags
(
floatX
=
"float64"
)
def
test_debugprint_nested_scans
():
def
test_debugprint_nested_scans
():
coefficients
=
dvector
(
"coefficients"
)
coefficients
=
dvector
(
"coefficients"
)
max_coefficients_supported
=
10
max_coefficients_supported
=
10
...
@@ -251,22 +255,22 @@ def test_debugprint_nested_scans():
...
@@ -251,22 +255,22 @@ def test_debugprint_nested_scans():
for{cpu,scan_fn} [id B] (outer_out_nit_sot-0)
for{cpu,scan_fn} [id B] (outer_out_nit_sot-0)
>Elemwise{mul,no_inplace} [id Y] (inner_out_nit_sot-0)
>Elemwise{mul,no_inplace} [id Y] (inner_out_nit_sot-0)
> |InplaceDimShuffle{x} [id Z]
> |InplaceDimShuffle{x} [id Z]
> | |
coefficients[t]
[id BA] -> [id S] (inner_in_seqs-0)
> | |
*0-<TensorType(float64, ())>
[id BA] -> [id S] (inner_in_seqs-0)
> |Elemwise{pow,no_inplace} [id BB]
> |Elemwise{pow,no_inplace} [id BB]
> |Subtensor{int64} [id BC]
> |Subtensor{int64} [id BC]
> | |Subtensor{int64::} [id BD]
> | |Subtensor{int64::} [id BD]
> | | |for{cpu,scan_fn} [id BE] (outer_out_sit_sot-0)
> | | |for{cpu,scan_fn} [id BE] (outer_out_sit_sot-0)
> | | | |
k_copy
[id BF] -> [id X] (inner_in_non_seqs-1) (n_steps)
> | | | |
*3-<TensorType(int32, ())>
[id BF] -> [id X] (inner_in_non_seqs-1) (n_steps)
> | | | |IncSubtensor{Set;:int64:} [id BG] (outer_in_sit_sot-0)
> | | | |IncSubtensor{Set;:int64:} [id BG] (outer_in_sit_sot-0)
> | | | | |AllocEmpty{dtype='float64'} [id BH]
> | | | | |AllocEmpty{dtype='float64'} [id BH]
> | | | | | |Elemwise{add,no_inplace} [id BI]
> | | | | | |Elemwise{add,no_inplace} [id BI]
> | | | | | | |
k_copy
[id BF] -> [id X] (inner_in_non_seqs-1)
> | | | | | | |
*3-<TensorType(int32, ())>
[id BF] -> [id X] (inner_in_non_seqs-1)
> | | | | | | |Subtensor{int64} [id BJ]
> | | | | | | |Subtensor{int64} [id BJ]
> | | | | | | |Shape [id BK]
> | | | | | | |Shape [id BK]
> | | | | | | | |Rebroadcast{(0, False)} [id BL]
> | | | | | | | |Rebroadcast{(0, False)} [id BL]
> | | | | | | | |InplaceDimShuffle{x,0} [id BM]
> | | | | | | | |InplaceDimShuffle{x,0} [id BM]
> | | | | | | | |Elemwise{second,no_inplace} [id BN]
> | | | | | | | |Elemwise{second,no_inplace} [id BN]
> | | | | | | | |
A_copy
[id BO] -> [id W] (inner_in_non_seqs-0)
> | | | | | | | |
*2-<TensorType(float64, (None,))>
[id BO] -> [id W] (inner_in_non_seqs-0)
> | | | | | | | |InplaceDimShuffle{x} [id BP]
> | | | | | | | |InplaceDimShuffle{x} [id BP]
> | | | | | | | |TensorConstant{1.0} [id BQ]
> | | | | | | | |TensorConstant{1.0} [id BQ]
> | | | | | | |ScalarConstant{0} [id BR]
> | | | | | | |ScalarConstant{0} [id BR]
...
@@ -277,21 +281,22 @@ def test_debugprint_nested_scans():
...
@@ -277,21 +281,22 @@ def test_debugprint_nested_scans():
> | | | | |Rebroadcast{(0, False)} [id BL]
> | | | | |Rebroadcast{(0, False)} [id BL]
> | | | | |ScalarFromTensor [id BV]
> | | | | |ScalarFromTensor [id BV]
> | | | | |Subtensor{int64} [id BJ]
> | | | | |Subtensor{int64} [id BJ]
> | | | |
A_copy
[id BO] -> [id W] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
> | | | |
*2-<TensorType(float64, (None,))>
[id BO] -> [id W] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
> | | |ScalarConstant{1} [id BW]
> | | |ScalarConstant{1} [id BW]
> | |ScalarConstant{-1} [id BX]
> | |ScalarConstant{-1} [id BX]
> |InplaceDimShuffle{x} [id BY]
> |InplaceDimShuffle{x} [id BY]
> |<TensorType(int64, ())> [id BZ] -> [id U] (inner_in_seqs-1)
> |
*1-
<TensorType(int64, ())> [id BZ] -> [id U] (inner_in_seqs-1)
for{cpu,scan_fn} [id BE] (outer_out_sit_sot-0)
for{cpu,scan_fn} [id BE] (outer_out_sit_sot-0)
>Elemwise{mul,no_inplace} [id CA] (inner_out_sit_sot-0)
>Elemwise{mul,no_inplace} [id CA] (inner_out_sit_sot-0)
> |<TensorType(float64, (None,))> [id CB] -> [id BG] (inner_in_sit_sot-0)
> |
*0-
<TensorType(float64, (None,))> [id CB] -> [id BG] (inner_in_sit_sot-0)
> |
A_copy
[id CC] -> [id BO] (inner_in_non_seqs-0)"""
> |
*1-<TensorType(float64, (None,))>
[id CC] -> [id BO] (inner_in_non_seqs-0)"""
for
truth
,
out
in
zip
(
expected_output
.
split
(
"
\n
"
),
lines
):
for
truth
,
out
in
zip
(
expected_output
.
split
(
"
\n
"
),
lines
):
assert
truth
.
strip
()
==
out
.
strip
()
assert
truth
.
strip
()
==
out
.
strip
()
@config.change_flags
(
floatX
=
"float64"
)
def
test_debugprint_mitsot
():
def
test_debugprint_mitsot
():
def
fn
(
a_m2
,
a_m1
,
b_m2
,
b_m1
):
def
fn
(
a_m2
,
a_m1
,
b_m2
,
b_m1
):
return
a_m1
+
a_m2
,
b_m1
+
b_m2
return
a_m1
+
a_m2
,
b_m1
+
b_m2
...
@@ -351,11 +356,11 @@ def test_debugprint_mitsot():
...
@@ -351,11 +356,11 @@ def test_debugprint_mitsot():
for{cpu,scan_fn}.0 [id C] (outer_out_mit_sot-0)
for{cpu,scan_fn}.0 [id C] (outer_out_mit_sot-0)
>Elemwise{add,no_inplace} [id BB] (inner_out_mit_sot-0)
>Elemwise{add,no_inplace} [id BB] (inner_out_mit_sot-0)
> |<TensorType(int64, ())> [id BC] -> [id E] (inner_in_mit_sot-0-1)
> |
*1-
<TensorType(int64, ())> [id BC] -> [id E] (inner_in_mit_sot-0-1)
> |<TensorType(int64, ())> [id BD] -> [id E] (inner_in_mit_sot-0-0)
> |
*0-
<TensorType(int64, ())> [id BD] -> [id E] (inner_in_mit_sot-0-0)
>Elemwise{add,no_inplace} [id BE] (inner_out_mit_sot-1)
>Elemwise{add,no_inplace} [id BE] (inner_out_mit_sot-1)
> |<TensorType(int64, ())> [id BF] -> [id O] (inner_in_mit_sot-1-1)
> |
*3-
<TensorType(int64, ())> [id BF] -> [id O] (inner_in_mit_sot-1-1)
> |<TensorType(int64, ())> [id BG] -> [id O] (inner_in_mit_sot-1-0)
> |
*2-
<TensorType(int64, ())> [id BG] -> [id O] (inner_in_mit_sot-1-0)
for{cpu,scan_fn}.1 [id C] (outer_out_mit_sot-1)
for{cpu,scan_fn}.1 [id C] (outer_out_mit_sot-1)
>Elemwise{add,no_inplace} [id BB] (inner_out_mit_sot-0)
>Elemwise{add,no_inplace} [id BB] (inner_out_mit_sot-0)
...
@@ -365,6 +370,7 @@ def test_debugprint_mitsot():
...
@@ -365,6 +370,7 @@ def test_debugprint_mitsot():
assert
truth
.
strip
()
==
out
.
strip
()
assert
truth
.
strip
()
==
out
.
strip
()
@config.change_flags
(
floatX
=
"float64"
)
def
test_debugprint_mitmot
():
def
test_debugprint_mitmot
():
k
=
iscalar
(
"k"
)
k
=
iscalar
(
"k"
)
...
@@ -471,19 +477,19 @@ def test_debugprint_mitmot():
...
@@ -471,19 +477,19 @@ def test_debugprint_mitmot():
for{cpu,grad_of_scan_fn}.1 [id B] (outer_out_sit_sot-0)
for{cpu,grad_of_scan_fn}.1 [id B] (outer_out_sit_sot-0)
>Elemwise{add,no_inplace} [id CM] (inner_out_mit_mot-0-0)
>Elemwise{add,no_inplace} [id CM] (inner_out_mit_mot-0-0)
> |Elemwise{mul} [id CN]
> |Elemwise{mul} [id CN]
> | |<TensorType(float64, (None,))> [id CO] -> [id BL] (inner_in_mit_mot-0-0)
> | |
*2-
<TensorType(float64, (None,))> [id CO] -> [id BL] (inner_in_mit_mot-0-0)
> | |
A_copy
[id CP] -> [id P] (inner_in_non_seqs-0)
> | |
*5-<TensorType(float64, (None,))>
[id CP] -> [id P] (inner_in_non_seqs-0)
> |<TensorType(float64, (None,))> [id CQ] -> [id BL] (inner_in_mit_mot-0-1)
> |
*3-
<TensorType(float64, (None,))> [id CQ] -> [id BL] (inner_in_mit_mot-0-1)
>Elemwise{add,no_inplace} [id CR] (inner_out_sit_sot-0)
>Elemwise{add,no_inplace} [id CR] (inner_out_sit_sot-0)
> |Elemwise{mul} [id CS]
> |Elemwise{mul} [id CS]
> | |<TensorType(float64, (None,))> [id CO] -> [id BL] (inner_in_mit_mot-0-0)
> | |
*2-
<TensorType(float64, (None,))> [id CO] -> [id BL] (inner_in_mit_mot-0-0)
> | |<TensorType(float64, (None,))> [id CT] -> [id Z] (inner_in_seqs-0)
> | |
*0-
<TensorType(float64, (None,))> [id CT] -> [id Z] (inner_in_seqs-0)
> |<TensorType(float64, (None,))> [id CU] -> [id CE] (inner_in_sit_sot-0)
> |
*4-
<TensorType(float64, (None,))> [id CU] -> [id CE] (inner_in_sit_sot-0)
for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
>Elemwise{mul,no_inplace} [id CV] (inner_out_sit_sot-0)
>Elemwise{mul,no_inplace} [id CV] (inner_out_sit_sot-0)
> |<TensorType(float64, (None,))> [id CT] -> [id H] (inner_in_sit_sot-0)
> |
*0-
<TensorType(float64, (None,))> [id CT] -> [id H] (inner_in_sit_sot-0)
> |
A_copy [id CP
] -> [id P] (inner_in_non_seqs-0)
> |
*1-<TensorType(float64, (None,))> [id CW
] -> [id P] (inner_in_non_seqs-0)
for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
>Elemwise{mul,no_inplace} [id CV] (inner_out_sit_sot-0)
>Elemwise{mul,no_inplace} [id CV] (inner_out_sit_sot-0)
...
@@ -540,11 +546,11 @@ def test_debugprint_compiled_fn():
...
@@ -540,11 +546,11 @@ def test_debugprint_compiled_fn():
>Elemwise{Composite{Switch(LT(i0, i1), i2, i0)}} [id I] (inner_out_sit_sot-0)
>Elemwise{Composite{Switch(LT(i0, i1), i2, i0)}} [id I] (inner_out_sit_sot-0)
> |TensorConstant{0} [id J]
> |TensorConstant{0} [id J]
> |Subtensor{int64, int64, int64} [id K]
> |Subtensor{int64, int64, int64} [id K]
> | |<TensorType(float64, (20000, 2, 2))> [id L] -> [id H] (inner_in_non_seqs-0)
> | |
*2-
<TensorType(float64, (20000, 2, 2))> [id L] -> [id H] (inner_in_non_seqs-0)
> | |ScalarFromTensor [id M]
> | |ScalarFromTensor [id M]
> | | |<TensorType(int64, ())> [id N] -> [id C] (inner_in_seqs-0)
> | | |
*0-
<TensorType(int64, ())> [id N] -> [id C] (inner_in_seqs-0)
> | |ScalarFromTensor [id O]
> | |ScalarFromTensor [id O]
> | | |<TensorType(int64, ())> [id P] -> [id D] (inner_in_sit_sot-0)
> | | |
*1-
<TensorType(int64, ())> [id P] -> [id D] (inner_in_sit_sot-0)
> | |ScalarConstant{0} [id Q]
> | |ScalarConstant{0} [id Q]
> |TensorConstant{1} [id R]
> |TensorConstant{1} [id R]
"""
"""
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论