提交 4d0aa3f2 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Return from scalar constants in `get_unique_constant_value`

上级 7a0ea76e
...@@ -1045,11 +1045,13 @@ def get_unique_constant_value(x: TensorVariable) -> Number | None: ...@@ -1045,11 +1045,13 @@ def get_unique_constant_value(x: TensorVariable) -> Number | None:
if isinstance(x, Constant): if isinstance(x, Constant):
data = x.data data = x.data
if isinstance(data, np.ndarray) and data.ndim > 0: if isinstance(data, np.ndarray) and data.size > 0:
if data.size == 1:
return data.squeeze()
flat_data = data.ravel() flat_data = data.ravel()
if flat_data.shape[0]: if (flat_data == flat_data[0]).all():
if (flat_data == flat_data[0]).all(): return flat_data[0]
return flat_data[0]
return None return None
......
...@@ -654,24 +654,22 @@ def test_debugprint_compiled_fn(): ...@@ -654,24 +654,22 @@ def test_debugprint_compiled_fn():
Inner graphs: Inner graphs:
Scan{scan_fn, while_loop=False, inplace=all} [id A] Scan{scan_fn, while_loop=False, inplace=all} [id A]
← Composite{switch(lt(i0, i1), i2, i0)} [id I] (inner_out_sit_sot-0) ← Composite{switch(lt(0, i0), 1, 0)} [id I] (inner_out_sit_sot-0)
├─ 0 [id J] └─ Subtensor{i, j, k} [id J]
├─ Subtensor{i, j, k} [id K] ├─ *2-<Tensor3(float64, shape=(20000, 2, 2))> [id K] -> [id H] (inner_in_non_seqs-0)
│ ├─ *2-<Tensor3(float64, shape=(20000, 2, 2))> [id L] -> [id H] (inner_in_non_seqs-0) ├─ ScalarFromTensor [id L]
│ ├─ ScalarFromTensor [id M] │ └─ *0-<Scalar(int64, shape=())> [id M] -> [id C] (inner_in_seqs-0)
│ │ └─ *0-<Scalar(int64, shape=())> [id N] -> [id C] (inner_in_seqs-0) ├─ ScalarFromTensor [id N]
│ ├─ ScalarFromTensor [id O] │ └─ *1-<Scalar(int64, shape=())> [id O] -> [id D] (inner_in_sit_sot-0)
│ │ └─ *1-<Scalar(int64, shape=())> [id P] -> [id D] (inner_in_sit_sot-0) └─ 0 [id P]
│ └─ 0 [id Q]
└─ 1 [id R] Composite{switch(lt(0, i0), 1, 0)} [id I]
← Switch [id Q] 'o0'
Composite{switch(lt(i0, i1), i2, i0)} [id I] ├─ LT [id R]
← Switch [id S] 'o0' │ ├─ 0 [id S]
├─ LT [id T] │ └─ i0 [id T]
│ ├─ i0 [id U] ├─ 1 [id U]
│ └─ i1 [id V] └─ 0 [id S]
├─ i2 [id W]
└─ i0 [id U]
""" """
output_str = debugprint(out, file="str", print_op_info=True) output_str = debugprint(out, file="str", print_op_info=True)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论