Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
2faa56a4
提交
2faa56a4
authored
1月 11, 2026
作者:
ricardoV94
提交者:
Ricardo Vieira
1月 12, 2026
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Do not redefine DisconnectedType everytime
上级
d8b51df8
显示空白字符变更
内嵌
并排
正在显示
16 个修改的文件
包含
72 行增加
和
61 行删除
+72
-61
creating_an_op.rst
doc/extending/creating_an_op.rst
+3
-3
breakpoint.py
pytensor/breakpoint.py
+2
-2
raise_op.py
pytensor/raise_op.py
+5
-2
basic.py
pytensor/scalar/basic.py
+3
-3
op.py
pytensor/scan/op.py
+15
-8
basic.py
pytensor/sparse/basic.py
+5
-5
basic.py
pytensor/tensor/basic.py
+6
-6
extra_ops.py
pytensor/tensor/extra_ops.py
+1
-2
fft.py
pytensor/tensor/fft.py
+3
-3
nlinalg.py
pytensor/tensor/nlinalg.py
+3
-3
reshape.py
pytensor/tensor/reshape.py
+2
-2
shape.py
pytensor/tensor/shape.py
+6
-5
conv.py
pytensor/tensor/signal/conv.py
+2
-2
slinalg.py
pytensor/tensor/slinalg.py
+2
-2
subtensor.py
pytensor/tensor/subtensor.py
+12
-11
type_other.py
pytensor/tensor/type_other.py
+2
-2
没有找到文件。
doc/extending/creating_an_op.rst
浏览文件 @
2faa56a4
...
@@ -495,7 +495,7 @@ If both outputs are disconnected PyTensor will not bother calling the :meth:`L_o
...
@@ -495,7 +495,7 @@ If both outputs are disconnected PyTensor will not bother calling the :meth:`L_o
from pytensor.graph.op import Op
from pytensor.graph.op import Op
from pytensor.graph.basic import Apply
from pytensor.graph.basic import Apply
from pytensor.gradient import DisconnectedType
from pytensor.gradient import DisconnectedType
, disconnected_type
class TransposeAndSumOp(Op):
class TransposeAndSumOp(Op):
__props__ = ()
__props__ = ()
...
@@ -539,13 +539,13 @@ If both outputs are disconnected PyTensor will not bother calling the :meth:`L_o
...
@@ -539,13 +539,13 @@ If both outputs are disconnected PyTensor will not bother calling the :meth:`L_o
out1_grad, out2_grad = output_grads
out1_grad, out2_grad = output_grads
if isinstance(out1_grad.type, DisconnectedType):
if isinstance(out1_grad.type, DisconnectedType):
x_grad =
DisconnectedType()
()
x_grad =
disconnected_type
()
else:
else:
# Transpose the last two dimensions of the output gradient
# Transpose the last two dimensions of the output gradient
x_grad = pt.swapaxes(out1_grad, -1, -2)
x_grad = pt.swapaxes(out1_grad, -1, -2)
if isinstance(out2_grad.type, DisconnectedType):
if isinstance(out2_grad.type, DisconnectedType):
y_grad =
DisconnectedType()
()
y_grad =
disconnected_type
()
else:
else:
# Broadcast the output gradient to the same shape as y
# Broadcast the output gradient to the same shape as y
y_grad = pt.broadcast_to(pt.expand_dims(out2_grad, -1), y.shape)
y_grad = pt.broadcast_to(pt.expand_dims(out2_grad, -1), y.shape)
...
...
pytensor/breakpoint.py
浏览文件 @
2faa56a4
import
numpy
as
np
import
numpy
as
np
from
pytensor.gradient
import
DisconnectedT
ype
from
pytensor.gradient
import
disconnected_t
ype
from
pytensor.graph.basic
import
Apply
,
Variable
from
pytensor.graph.basic
import
Apply
,
Variable
from
pytensor.graph.op
import
Op
from
pytensor.graph.op
import
Op
from
pytensor.tensor.basic
import
as_tensor_variable
from
pytensor.tensor.basic
import
as_tensor_variable
...
@@ -142,7 +142,7 @@ class PdbBreakpoint(Op):
...
@@ -142,7 +142,7 @@ class PdbBreakpoint(Op):
output_storage
[
i
][
0
]
=
inputs
[
i
+
1
]
output_storage
[
i
][
0
]
=
inputs
[
i
+
1
]
def
grad
(
self
,
inputs
,
output_gradients
):
def
grad
(
self
,
inputs
,
output_gradients
):
return
[
DisconnectedType
()
(),
*
output_gradients
]
return
[
disconnected_type
(),
*
output_gradients
]
def
infer_shape
(
self
,
fgraph
,
inputs
,
input_shapes
):
def
infer_shape
(
self
,
fgraph
,
inputs
,
input_shapes
):
# Return the shape of every input but the condition (first input)
# Return the shape of every input but the condition (first input)
...
...
pytensor/raise_op.py
浏览文件 @
2faa56a4
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
from
textwrap
import
indent
from
textwrap
import
indent
from
pytensor.gradient
import
DisconnectedT
ype
from
pytensor.gradient
import
disconnected_t
ype
from
pytensor.graph.basic
import
Apply
,
Constant
,
Variable
from
pytensor.graph.basic
import
Apply
,
Constant
,
Variable
from
pytensor.graph.replace
import
_vectorize_node
from
pytensor.graph.replace
import
_vectorize_node
from
pytensor.link.c.op
import
COp
from
pytensor.link.c.op
import
COp
...
@@ -89,7 +89,10 @@ class CheckAndRaise(COp):
...
@@ -89,7 +89,10 @@ class CheckAndRaise(COp):
raise
self
.
exc_type
(
self
.
msg
)
raise
self
.
exc_type
(
self
.
msg
)
def
grad
(
self
,
input
,
output_gradients
):
def
grad
(
self
,
input
,
output_gradients
):
return
output_gradients
+
[
DisconnectedType
()()]
*
(
len
(
input
)
-
1
)
return
[
*
output_gradients
,
*
(
disconnected_type
()
for
_
in
range
(
len
(
input
)
-
1
)),
]
def
connection_pattern
(
self
,
node
):
def
connection_pattern
(
self
,
node
):
return
[[
1
]]
+
[[
0
]]
*
(
len
(
node
.
inputs
)
-
1
)
return
[[
1
]]
+
[[
0
]]
*
(
len
(
node
.
inputs
)
-
1
)
...
...
pytensor/scalar/basic.py
浏览文件 @
2faa56a4
...
@@ -22,7 +22,7 @@ import numpy as np
...
@@ -22,7 +22,7 @@ import numpy as np
import
pytensor
import
pytensor
from
pytensor
import
printing
from
pytensor
import
printing
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.gradient
import
DisconnectedT
ype
,
grad_undefined
from
pytensor.gradient
import
disconnected_t
ype
,
grad_undefined
from
pytensor.graph.basic
import
Apply
,
Constant
,
Variable
,
clone
from
pytensor.graph.basic
import
Apply
,
Constant
,
Variable
,
clone
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.op
import
HasInnerGraph
from
pytensor.graph.op
import
HasInnerGraph
...
@@ -2426,13 +2426,13 @@ class Second(BinaryScalarOp):
...
@@ -2426,13 +2426,13 @@ class Second(BinaryScalarOp):
(
gz
,)
=
gout
(
gz
,)
=
gout
if
y
.
type
in
continuous_types
:
if
y
.
type
in
continuous_types
:
# x is disconnected because the elements of x are not used
# x is disconnected because the elements of x are not used
return
DisconnectedType
()
(),
gz
return
disconnected_type
(),
gz
else
:
else
:
# when y is discrete, we assume the function can be extended
# when y is discrete, we assume the function can be extended
# to deal with real-valued inputs by rounding them to the
# to deal with real-valued inputs by rounding them to the
# nearest integer. f(x+eps) thus equals f(x) so the gradient
# nearest integer. f(x+eps) thus equals f(x) so the gradient
# is zero, not disconnected or undefined
# is zero, not disconnected or undefined
return
DisconnectedType
()
(),
y
.
zeros_like
(
dtype
=
config
.
floatX
)
return
disconnected_type
(),
y
.
zeros_like
(
dtype
=
config
.
floatX
)
second
=
Second
(
name
=
"second"
)
second
=
Second
(
name
=
"second"
)
...
...
pytensor/scan/op.py
浏览文件 @
2faa56a4
...
@@ -63,7 +63,14 @@ from pytensor.compile.io import In, Out
...
@@ -63,7 +63,14 @@ from pytensor.compile.io import In, Out
from
pytensor.compile.mode
import
Mode
,
get_mode
from
pytensor.compile.mode
import
Mode
,
get_mode
from
pytensor.compile.profiling
import
register_profiler_printer
from
pytensor.compile.profiling
import
register_profiler_printer
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.gradient
import
DisconnectedType
,
NullType
,
Rop
,
grad
,
grad_undefined
from
pytensor.gradient
import
(
DisconnectedType
,
NullType
,
Rop
,
disconnected_type
,
grad
,
grad_undefined
,
)
from
pytensor.graph.basic
import
(
from
pytensor.graph.basic
import
(
Apply
,
Apply
,
Variable
,
Variable
,
...
@@ -3073,7 +3080,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -3073,7 +3080,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
)
)
outputs
=
local_op
(
*
outer_inputs
,
return_list
=
True
)
outputs
=
local_op
(
*
outer_inputs
,
return_list
=
True
)
# Re-order the gradients correctly
# Re-order the gradients correctly
gradients
=
[
DisconnectedType
()()]
gradients
=
[
disconnected_type
()]
# n_steps is disconnected
offset
=
info
.
n_mit_mot
+
info
.
n_mit_sot
+
info
.
n_sit_sot
+
n_sitsot_outs
offset
=
info
.
n_mit_mot
+
info
.
n_mit_sot
+
info
.
n_sit_sot
+
n_sitsot_outs
for
p
,
(
x
,
t
)
in
enumerate
(
for
p
,
(
x
,
t
)
in
enumerate
(
...
@@ -3098,7 +3105,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -3098,7 +3105,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
else
:
else
:
gradients
.
append
(
x
[::
-
1
])
gradients
.
append
(
x
[::
-
1
])
elif
t
==
"disconnected"
:
elif
t
==
"disconnected"
:
gradients
.
append
(
DisconnectedType
()
())
gradients
.
append
(
disconnected_type
())
elif
t
==
"through_untraced"
:
elif
t
==
"through_untraced"
:
gradients
.
append
(
gradients
.
append
(
grad_undefined
(
grad_undefined
(
...
@@ -3126,7 +3133,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -3126,7 +3133,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
else
:
else
:
gradients
.
append
(
x
[::
-
1
])
gradients
.
append
(
x
[::
-
1
])
elif
t
==
"disconnected"
:
elif
t
==
"disconnected"
:
gradients
.
append
(
DisconnectedType
()
())
gradients
.
append
(
disconnected_type
())
elif
t
==
"through_untraced"
:
elif
t
==
"through_untraced"
:
gradients
.
append
(
gradients
.
append
(
grad_undefined
(
grad_undefined
(
...
@@ -3149,7 +3156,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -3149,7 +3156,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
if
not
isinstance
(
dC_dout
.
type
,
DisconnectedType
)
and
connected
:
if
not
isinstance
(
dC_dout
.
type
,
DisconnectedType
)
and
connected
:
disconnected
=
False
disconnected
=
False
if
disconnected
:
if
disconnected
:
gradients
.
append
(
DisconnectedType
()
())
gradients
.
append
(
disconnected_type
())
else
:
else
:
gradients
.
append
(
gradients
.
append
(
grad_undefined
(
grad_undefined
(
...
@@ -3157,7 +3164,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -3157,7 +3164,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
)
)
)
)
gradients
+=
[
DisconnectedType
()()
for
_
in
range
(
info
.
n_nit_sot
)]
gradients
.
extend
(
disconnected_type
()
for
_
in
range
(
info
.
n_nit_sot
))
begin
=
end
begin
=
end
end
=
begin
+
n_sitsot_outs
end
=
begin
+
n_sitsot_outs
...
@@ -3167,7 +3174,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -3167,7 +3174,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
if
t
==
"connected"
:
if
t
==
"connected"
:
gradients
.
append
(
x
[
-
1
])
gradients
.
append
(
x
[
-
1
])
elif
t
==
"disconnected"
:
elif
t
==
"disconnected"
:
gradients
.
append
(
DisconnectedType
()
())
gradients
.
append
(
disconnected_type
())
elif
t
==
"through_untraced"
:
elif
t
==
"through_untraced"
:
gradients
.
append
(
gradients
.
append
(
grad_undefined
(
grad_undefined
(
...
@@ -3195,7 +3202,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -3195,7 +3202,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
):
):
disconnected
=
False
disconnected
=
False
if
disconnected
:
if
disconnected
:
gradients
[
idx
]
=
DisconnectedType
()
()
gradients
[
idx
]
=
disconnected_type
()
return
gradients
return
gradients
def
R_op
(
self
,
inputs
,
eval_points
):
def
R_op
(
self
,
inputs
,
eval_points
):
...
...
pytensor/sparse/basic.py
浏览文件 @
2faa56a4
...
@@ -18,7 +18,7 @@ import pytensor
...
@@ -18,7 +18,7 @@ import pytensor
from
pytensor
import
_as_symbolic
,
as_symbolic
from
pytensor
import
_as_symbolic
,
as_symbolic
from
pytensor
import
scalar
as
ps
from
pytensor
import
scalar
as
ps
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.gradient
import
DisconnectedType
,
grad_undefined
from
pytensor.gradient
import
DisconnectedType
,
disconnected_type
,
grad_undefined
from
pytensor.graph.basic
import
Apply
,
Constant
,
Variable
from
pytensor.graph.basic
import
Apply
,
Constant
,
Variable
from
pytensor.graph.op
import
Op
from
pytensor.graph.op
import
Op
from
pytensor.link.c.type
import
generic
from
pytensor.link.c.type
import
generic
...
@@ -480,9 +480,9 @@ class CSM(Op):
...
@@ -480,9 +480,9 @@ class CSM(Op):
)
)
return
[
return
[
g_data
,
g_data
,
DisconnectedType
()
(),
disconnected_type
(),
DisconnectedType
()
(),
disconnected_type
(),
DisconnectedType
()
(),
disconnected_type
(),
]
]
def
infer_shape
(
self
,
fgraph
,
node
,
shapes
):
def
infer_shape
(
self
,
fgraph
,
node
,
shapes
):
...
@@ -1940,7 +1940,7 @@ class ConstructSparseFromList(Op):
...
@@ -1940,7 +1940,7 @@ class ConstructSparseFromList(Op):
gx
=
g_output
gx
=
g_output
gy
=
pytensor
.
tensor
.
subtensor
.
advanced_subtensor1
(
g_output
,
*
idx_list
)
gy
=
pytensor
.
tensor
.
subtensor
.
advanced_subtensor1
(
g_output
,
*
idx_list
)
return
[
gx
,
gy
]
+
[
DisconnectedType
()()]
*
len
(
idx_list
)
return
[
gx
,
gy
,
*
(
disconnected_type
()
for
_
in
range
(
len
(
idx_list
)))]
construct_sparse_from_list
=
ConstructSparseFromList
()
construct_sparse_from_list
=
ConstructSparseFromList
()
pytensor/tensor/basic.py
浏览文件 @
2faa56a4
...
@@ -22,7 +22,7 @@ import pytensor.scalar.sharedvar
...
@@ -22,7 +22,7 @@ import pytensor.scalar.sharedvar
from
pytensor
import
config
,
printing
from
pytensor
import
config
,
printing
from
pytensor
import
scalar
as
ps
from
pytensor
import
scalar
as
ps
from
pytensor.compile.builders
import
OpFromGraph
from
pytensor.compile.builders
import
OpFromGraph
from
pytensor.gradient
import
DisconnectedType
,
grad_undefined
from
pytensor.gradient
import
DisconnectedType
,
disconnected_type
,
grad_undefined
from
pytensor.graph
import
RewriteDatabaseQuery
from
pytensor.graph
import
RewriteDatabaseQuery
from
pytensor.graph.basic
import
Apply
,
Constant
,
Variable
,
equal_computations
from
pytensor.graph.basic
import
Apply
,
Constant
,
Variable
,
equal_computations
from
pytensor.graph.fg
import
FunctionGraph
,
Output
from
pytensor.graph.fg
import
FunctionGraph
,
Output
...
@@ -1738,7 +1738,7 @@ class Alloc(COp):
...
@@ -1738,7 +1738,7 @@ class Alloc(COp):
# the inputs that specify the shape. If you grow the
# the inputs that specify the shape. If you grow the
# shape by epsilon, the existing elements do not
# shape by epsilon, the existing elements do not
# change.
# change.
return
[
gx
]
+
[
DisconnectedType
()()
for
i
in
inputs
[
1
:]
]
return
[
gx
,
*
(
disconnected_type
()
for
_
in
range
(
len
(
inputs
)
-
1
))
]
def
R_op
(
self
,
inputs
,
eval_points
):
def
R_op
(
self
,
inputs
,
eval_points
):
if
eval_points
[
0
]
is
None
:
if
eval_points
[
0
]
is
None
:
...
@@ -2277,7 +2277,7 @@ class Split(COp):
...
@@ -2277,7 +2277,7 @@ class Split(COp):
return
[
return
[
join
(
axis
,
*
new_g_outputs
),
join
(
axis
,
*
new_g_outputs
),
grad_undefined
(
self
,
1
,
axis
),
grad_undefined
(
self
,
1
,
axis
),
DisconnectedType
()
(),
disconnected_type
(),
]
]
def
R_op
(
self
,
inputs
,
eval_points
):
def
R_op
(
self
,
inputs
,
eval_points
):
...
@@ -3340,14 +3340,14 @@ class ARange(COp):
...
@@ -3340,14 +3340,14 @@ class ARange(COp):
if
self
.
dtype
in
discrete_dtypes
:
if
self
.
dtype
in
discrete_dtypes
:
return
[
return
[
start
.
zeros_like
(
dtype
=
config
.
floatX
),
start
.
zeros_like
(
dtype
=
config
.
floatX
),
DisconnectedType
()
(),
disconnected_type
(),
step
.
zeros_like
(
dtype
=
config
.
floatX
),
step
.
zeros_like
(
dtype
=
config
.
floatX
),
]
]
else
:
else
:
num_steps_taken
=
outputs
[
0
]
.
shape
[
0
]
num_steps_taken
=
outputs
[
0
]
.
shape
[
0
]
return
[
return
[
gz
.
sum
(),
gz
.
sum
(),
DisconnectedType
()
(),
disconnected_type
(),
(
gz
*
arange
(
num_steps_taken
,
dtype
=
self
.
dtype
))
.
sum
(),
(
gz
*
arange
(
num_steps_taken
,
dtype
=
self
.
dtype
))
.
sum
(),
]
]
...
@@ -4374,7 +4374,7 @@ class AllocEmpty(COp):
...
@@ -4374,7 +4374,7 @@ class AllocEmpty(COp):
return
[[
False
]
for
i
in
node
.
inputs
]
return
[[
False
]
for
i
in
node
.
inputs
]
def
grad
(
self
,
inputs
,
grads
):
def
grad
(
self
,
inputs
,
grads
):
return
[
DisconnectedType
()()
for
i
in
inputs
]
return
[
disconnected_type
()
for
_
in
range
(
len
(
inputs
))
]
def
R_op
(
self
,
inputs
,
eval_points
):
def
R_op
(
self
,
inputs
,
eval_points
):
return
[
zeros
(
inputs
,
self
.
dtype
)]
return
[
zeros
(
inputs
,
self
.
dtype
)]
...
...
pytensor/tensor/extra_ops.py
浏览文件 @
2faa56a4
...
@@ -8,7 +8,6 @@ from numpy.lib.array_utils import normalize_axis_index
...
@@ -8,7 +8,6 @@ from numpy.lib.array_utils import normalize_axis_index
import
pytensor
import
pytensor
import
pytensor.scalar.basic
as
ps
import
pytensor.scalar.basic
as
ps
from
pytensor.gradient
import
(
from
pytensor.gradient
import
(
DisconnectedType
,
_float_zeros_like
,
_float_zeros_like
,
disconnected_type
,
disconnected_type
,
grad_undefined
,
grad_undefined
,
...
@@ -716,7 +715,7 @@ class Repeat(Op):
...
@@ -716,7 +715,7 @@ class Repeat(Op):
gx_transpose
=
ptb
.
zeros_like
(
x_transpose
)[
repeated_arange
]
.
inc
(
gz_transpose
)
gx_transpose
=
ptb
.
zeros_like
(
x_transpose
)[
repeated_arange
]
.
inc
(
gz_transpose
)
gx
=
ptb
.
moveaxis
(
gx_transpose
,
0
,
axis
)
gx
=
ptb
.
moveaxis
(
gx_transpose
,
0
,
axis
)
return
[
gx
,
DisconnectedType
()
()]
return
[
gx
,
disconnected_type
()]
def
infer_shape
(
self
,
fgraph
,
node
,
ins_shapes
):
def
infer_shape
(
self
,
fgraph
,
node
,
ins_shapes
):
i0_shapes
=
ins_shapes
[
0
]
i0_shapes
=
ins_shapes
[
0
]
...
...
pytensor/tensor/fft.py
浏览文件 @
2faa56a4
import
numpy
as
np
import
numpy
as
np
from
pytensor.gradient
import
DisconnectedT
ype
from
pytensor.gradient
import
disconnected_t
ype
from
pytensor.graph.basic
import
Apply
from
pytensor.graph.basic
import
Apply
from
pytensor.graph.op
import
Op
from
pytensor.graph.op
import
Op
from
pytensor.tensor.basic
import
as_tensor_variable
from
pytensor.tensor.basic
import
as_tensor_variable
...
@@ -59,7 +59,7 @@ class RFFTOp(Op):
...
@@ -59,7 +59,7 @@ class RFFTOp(Op):
+
[
slice
(
None
)]
+
[
slice
(
None
)]
)
)
gout
=
set_subtensor
(
gout
[
idx
],
gout
[
idx
]
*
0.5
)
gout
=
set_subtensor
(
gout
[
idx
],
gout
[
idx
]
*
0.5
)
return
[
irfft_op
(
gout
,
s
),
DisconnectedType
()
()]
return
[
irfft_op
(
gout
,
s
),
disconnected_type
()]
def
connection_pattern
(
self
,
node
):
def
connection_pattern
(
self
,
node
):
# Specify that shape input parameter has no connection to graph and gradients.
# Specify that shape input parameter has no connection to graph and gradients.
...
@@ -121,7 +121,7 @@ class IRFFTOp(Op):
...
@@ -121,7 +121,7 @@ class IRFFTOp(Op):
+
[
slice
(
None
)]
+
[
slice
(
None
)]
)
)
gf
=
set_subtensor
(
gf
[
idx
],
gf
[
idx
]
*
2
)
gf
=
set_subtensor
(
gf
[
idx
],
gf
[
idx
]
*
2
)
return
[
gf
,
DisconnectedType
()
()]
return
[
gf
,
disconnected_type
()]
def
connection_pattern
(
self
,
node
):
def
connection_pattern
(
self
,
node
):
# Specify that shape input parameter has no connection to graph and gradients.
# Specify that shape input parameter has no connection to graph and gradients.
...
...
pytensor/tensor/nlinalg.py
浏览文件 @
2faa56a4
...
@@ -8,7 +8,7 @@ from numpy.lib.array_utils import normalize_axis_tuple
...
@@ -8,7 +8,7 @@ from numpy.lib.array_utils import normalize_axis_tuple
from
pytensor
import
scalar
as
ps
from
pytensor
import
scalar
as
ps
from
pytensor.compile.builders
import
OpFromGraph
from
pytensor.compile.builders
import
OpFromGraph
from
pytensor.gradient
import
DisconnectedType
from
pytensor.gradient
import
DisconnectedType
,
disconnected_type
from
pytensor.graph.basic
import
Apply
from
pytensor.graph.basic
import
Apply
from
pytensor.graph.op
import
Op
from
pytensor.graph.op
import
Op
from
pytensor.tensor
import
TensorLike
from
pytensor.tensor
import
TensorLike
...
@@ -652,8 +652,8 @@ class SVD(Op):
...
@@ -652,8 +652,8 @@ class SVD(Op):
]
]
if
all
(
is_disconnected
):
if
all
(
is_disconnected
):
# This should never actually be reached by Pytensor -- the SVD Op should be pruned from the gradient
# This should never actually be reached by Pytensor -- the SVD Op should be pruned from the gradient
# graph if its fully disconnected. It is included for completeness.
# graph if it
'
s fully disconnected. It is included for completeness.
return
[
DisconnectedType
()
()]
# pragma: no cover
return
[
disconnected_type
()]
# pragma: no cover
elif
is_disconnected
==
[
True
,
False
,
True
]:
elif
is_disconnected
==
[
True
,
False
,
True
]:
# This is the same as the compute_uv = False, so we can drop back to that simpler computation, without
# This is the same as the compute_uv = False, so we can drop back to that simpler computation, without
...
...
pytensor/tensor/reshape.py
浏览文件 @
2faa56a4
...
@@ -6,7 +6,7 @@ import numpy as np
...
@@ -6,7 +6,7 @@ import numpy as np
from
numpy.lib._array_utils_impl
import
normalize_axis_index
,
normalize_axis_tuple
from
numpy.lib._array_utils_impl
import
normalize_axis_index
,
normalize_axis_tuple
from
pytensor
import
Variable
from
pytensor
import
Variable
from
pytensor.gradient
import
DisconnectedT
ype
from
pytensor.gradient
import
disconnected_t
ype
from
pytensor.graph
import
Apply
from
pytensor.graph
import
Apply
from
pytensor.graph.op
import
Op
from
pytensor.graph.op
import
Op
from
pytensor.graph.replace
import
_vectorize_node
from
pytensor.graph.replace
import
_vectorize_node
...
@@ -217,7 +217,7 @@ class SplitDims(Op):
...
@@ -217,7 +217,7 @@ class SplitDims(Op):
n_axes
=
g_out
.
ndim
-
x
.
ndim
+
1
n_axes
=
g_out
.
ndim
-
x
.
ndim
+
1
axis_range
=
list
(
range
(
self
.
axis
,
self
.
axis
+
n_axes
))
axis_range
=
list
(
range
(
self
.
axis
,
self
.
axis
+
n_axes
))
return
[
join_dims
(
g_out
,
axis
=
axis_range
),
DisconnectedType
()
()]
return
[
join_dims
(
g_out
,
axis
=
axis_range
),
disconnected_type
()]
@_vectorize_node.register
(
SplitDims
)
@_vectorize_node.register
(
SplitDims
)
...
...
pytensor/tensor/shape.py
浏览文件 @
2faa56a4
...
@@ -10,7 +10,7 @@ import numpy as np
...
@@ -10,7 +10,7 @@ import numpy as np
from
numpy.lib.array_utils
import
normalize_axis_tuple
from
numpy.lib.array_utils
import
normalize_axis_tuple
import
pytensor
import
pytensor
from
pytensor.gradient
import
DisconnectedT
ype
from
pytensor.gradient
import
disconnected_t
ype
from
pytensor.graph
import
Op
from
pytensor.graph
import
Op
from
pytensor.graph.basic
import
Apply
,
Variable
from
pytensor.graph.basic
import
Apply
,
Variable
from
pytensor.graph.replace
import
_vectorize_node
from
pytensor.graph.replace
import
_vectorize_node
...
@@ -103,7 +103,7 @@ class Shape(COp):
...
@@ -103,7 +103,7 @@ class Shape(COp):
# the elements of the tensor variable do not participate
# the elements of the tensor variable do not participate
# in the computation of the shape, so they are not really
# in the computation of the shape, so they are not really
# part of the graph
# part of the graph
return
[
pytensor
.
gradient
.
DisconnectedType
()
()]
return
[
disconnected_type
()]
def
R_op
(
self
,
inputs
,
eval_points
):
def
R_op
(
self
,
inputs
,
eval_points
):
return
[
None
]
return
[
None
]
...
@@ -474,8 +474,9 @@ class SpecifyShape(COp):
...
@@ -474,8 +474,9 @@ class SpecifyShape(COp):
def
grad
(
self
,
inp
,
grads
):
def
grad
(
self
,
inp
,
grads
):
_x
,
*
shape
=
inp
_x
,
*
shape
=
inp
(
gz
,)
=
grads
(
gz
,)
=
grads
return
[
specify_shape
(
gz
,
shape
)]
+
[
return
[
pytensor
.
gradient
.
DisconnectedType
()()
for
_
in
range
(
len
(
shape
))
specify_shape
(
gz
,
shape
),
*
(
disconnected_type
()
for
_
in
range
(
len
(
shape
))),
]
]
def
R_op
(
self
,
inputs
,
eval_points
):
def
R_op
(
self
,
inputs
,
eval_points
):
...
@@ -725,7 +726,7 @@ class Reshape(COp):
...
@@ -725,7 +726,7 @@ class Reshape(COp):
def
grad
(
self
,
inp
,
grads
):
def
grad
(
self
,
inp
,
grads
):
x
,
_shp
=
inp
x
,
_shp
=
inp
(
g_out
,)
=
grads
(
g_out
,)
=
grads
return
[
reshape
(
g_out
,
shape
(
x
),
ndim
=
x
.
ndim
),
DisconnectedType
()
()]
return
[
reshape
(
g_out
,
shape
(
x
),
ndim
=
x
.
ndim
),
disconnected_type
()]
def
R_op
(
self
,
inputs
,
eval_points
):
def
R_op
(
self
,
inputs
,
eval_points
):
if
eval_points
[
0
]
is
None
:
if
eval_points
[
0
]
is
None
:
...
...
pytensor/tensor/signal/conv.py
浏览文件 @
2faa56a4
...
@@ -5,7 +5,7 @@ import numpy as np
...
@@ -5,7 +5,7 @@ import numpy as np
from
numpy
import
convolve
as
numpy_convolve
from
numpy
import
convolve
as
numpy_convolve
from
scipy.signal
import
convolve
as
scipy_convolve
from
scipy.signal
import
convolve
as
scipy_convolve
from
pytensor.gradient
import
DisconnectedT
ype
from
pytensor.gradient
import
disconnected_t
ype
from
pytensor.graph
import
Apply
,
Constant
from
pytensor.graph
import
Apply
,
Constant
from
pytensor.graph.op
import
Op
from
pytensor.graph.op
import
Op
from
pytensor.link.c.op
import
COp
from
pytensor.link.c.op
import
COp
...
@@ -109,7 +109,7 @@ class AbstractConvolveNd:
...
@@ -109,7 +109,7 @@ class AbstractConvolveNd:
return
[
return
[
self
(
grad
,
flip
(
in2
),
full_mode_in1_bar
),
self
(
grad
,
flip
(
in2
),
full_mode_in1_bar
),
self
(
grad
,
flip
(
in1
),
full_mode_in2_bar
),
self
(
grad
,
flip
(
in1
),
full_mode_in2_bar
),
DisconnectedType
()
(),
disconnected_type
(),
]
]
...
...
pytensor/tensor/slinalg.py
浏览文件 @
2faa56a4
...
@@ -11,7 +11,7 @@ from scipy.linalg import get_lapack_funcs
...
@@ -11,7 +11,7 @@ from scipy.linalg import get_lapack_funcs
import
pytensor
import
pytensor
from
pytensor
import
ifelse
from
pytensor
import
ifelse
from
pytensor
import
tensor
as
pt
from
pytensor
import
tensor
as
pt
from
pytensor.gradient
import
DisconnectedType
from
pytensor.gradient
import
DisconnectedType
,
disconnected_type
from
pytensor.graph.basic
import
Apply
from
pytensor.graph.basic
import
Apply
from
pytensor.graph.op
import
Op
from
pytensor.graph.op
import
Op
from
pytensor.raise_op
import
Assert
,
CheckAndRaise
from
pytensor.raise_op
import
Assert
,
CheckAndRaise
...
@@ -1966,7 +1966,7 @@ class QR(Op):
...
@@ -1966,7 +1966,7 @@ class QR(Op):
]
]
if
all
(
is_disconnected
):
if
all
(
is_disconnected
):
# This should never be reached by Pytensor
# This should never be reached by Pytensor
return
[
DisconnectedType
()
()]
# pragma: no cover
return
[
disconnected_type
()]
# pragma: no cover
for
disconnected
,
output_grad
,
output
in
zip
(
for
disconnected
,
output_grad
,
output
in
zip
(
is_disconnected
,
output_grads
,
[
Q
,
R
],
strict
=
True
is_disconnected
,
output_grads
,
[
Q
,
R
],
strict
=
True
...
...
pytensor/tensor/subtensor.py
浏览文件 @
2faa56a4
...
@@ -11,7 +11,7 @@ from numpy.lib.array_utils import normalize_axis_tuple
...
@@ -11,7 +11,7 @@ from numpy.lib.array_utils import normalize_axis_tuple
import
pytensor
import
pytensor
from
pytensor
import
scalar
as
ps
from
pytensor
import
scalar
as
ps
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.gradient
import
DisconnectedT
ype
from
pytensor.gradient
import
disconnected_t
ype
from
pytensor.graph.basic
import
Apply
,
Constant
,
Variable
from
pytensor.graph.basic
import
Apply
,
Constant
,
Variable
from
pytensor.graph.op
import
Op
from
pytensor.graph.op
import
Op
from
pytensor.graph.replace
import
_vectorize_node
from
pytensor.graph.replace
import
_vectorize_node
...
@@ -988,7 +988,7 @@ class Subtensor(COp):
...
@@ -988,7 +988,7 @@ class Subtensor(COp):
# set subtensor here at:
# set subtensor here at:
# pytensor/tensor/opt.py:local_incsubtensor_of_zeros_to_setsubtensor()
# pytensor/tensor/opt.py:local_incsubtensor_of_zeros_to_setsubtensor()
first
=
IncSubtensor
(
self
.
idx_list
)(
x
.
zeros_like
(),
gz
,
*
rest
)
first
=
IncSubtensor
(
self
.
idx_list
)(
x
.
zeros_like
(),
gz
,
*
rest
)
return
[
first
]
+
[
DisconnectedType
()()]
*
len
(
rest
)
return
[
first
,
*
(
disconnected_type
()
for
_
in
range
(
len
(
rest
)))]
def
connection_pattern
(
self
,
node
):
def
connection_pattern
(
self
,
node
):
rval
=
[[
True
],
*
([
False
]
for
_
in
node
.
inputs
[
1
:])]
rval
=
[[
True
],
*
([
False
]
for
_
in
node
.
inputs
[
1
:])]
...
@@ -2023,7 +2023,7 @@ class IncSubtensor(COp):
...
@@ -2023,7 +2023,7 @@ class IncSubtensor(COp):
gy
=
Subtensor
(
idx_list
=
self
.
idx_list
)(
g_output
,
*
idx_list
)
gy
=
Subtensor
(
idx_list
=
self
.
idx_list
)(
g_output
,
*
idx_list
)
gy
=
_sum_grad_over_bcasted_dims
(
y
,
gy
)
gy
=
_sum_grad_over_bcasted_dims
(
y
,
gy
)
return
[
gx
,
gy
]
+
[
DisconnectedType
()()]
*
len
(
idx_list
)
return
[
gx
,
gy
,
*
(
disconnected_type
()
for
_
in
range
(
len
(
idx_list
)))]
class
IncSubtensorPrinter
(
SubtensorPrinter
):
class
IncSubtensorPrinter
(
SubtensorPrinter
):
...
@@ -2135,7 +2135,7 @@ class AdvancedSubtensor1(COp):
...
@@ -2135,7 +2135,7 @@ class AdvancedSubtensor1(COp):
" from a tensor with ndim != 2. ndim is "
+
str
(
x
.
type
.
ndim
)
" from a tensor with ndim != 2. ndim is "
+
str
(
x
.
type
.
ndim
)
)
)
rval1
=
[
pytensor
.
sparse
.
construct_sparse_from_list
(
x
,
gz
,
ilist
)]
rval1
=
pytensor
.
sparse
.
construct_sparse_from_list
(
x
,
gz
,
ilist
)
else
:
else
:
if
x
.
dtype
in
discrete_dtypes
:
if
x
.
dtype
in
discrete_dtypes
:
# The output dtype is the same as x
# The output dtype is the same as x
...
@@ -2144,8 +2144,8 @@ class AdvancedSubtensor1(COp):
...
@@ -2144,8 +2144,8 @@ class AdvancedSubtensor1(COp):
raise
NotImplementedError
(
"No support for complex grad yet"
)
raise
NotImplementedError
(
"No support for complex grad yet"
)
else
:
else
:
gx
=
x
.
zeros_like
()
gx
=
x
.
zeros_like
()
rval1
=
[
advanced_inc_subtensor1
(
gx
,
gz
,
ilist
)]
rval1
=
advanced_inc_subtensor1
(
gx
,
gz
,
ilist
)
return
rval1
+
[
DisconnectedType
()()]
*
(
len
(
inputs
)
-
1
)
return
[
rval1
,
*
(
disconnected_type
()
for
_
in
range
(
len
(
inputs
)
-
1
))]
def
R_op
(
self
,
inputs
,
eval_points
):
def
R_op
(
self
,
inputs
,
eval_points
):
if
eval_points
[
0
]
is
None
:
if
eval_points
[
0
]
is
None
:
...
@@ -2519,7 +2519,7 @@ class AdvancedIncSubtensor1(COp):
...
@@ -2519,7 +2519,7 @@ class AdvancedIncSubtensor1(COp):
gy
=
advanced_subtensor1
(
g_output
,
idx_list
)
gy
=
advanced_subtensor1
(
g_output
,
idx_list
)
gy
=
_sum_grad_over_bcasted_dims
(
y
,
gy
)
gy
=
_sum_grad_over_bcasted_dims
(
y
,
gy
)
return
[
gx
,
gy
,
DisconnectedType
()
()]
return
[
gx
,
gy
,
disconnected_type
()]
advanced_inc_subtensor1
=
AdvancedIncSubtensor1
()
advanced_inc_subtensor1
=
AdvancedIncSubtensor1
()
...
@@ -2771,9 +2771,10 @@ class AdvancedSubtensor(Op):
...
@@ -2771,9 +2771,10 @@ class AdvancedSubtensor(Op):
else
:
else
:
gx
=
x
.
zeros_like
()
gx
=
x
.
zeros_like
()
rest
=
inputs
[
1
:]
rest
=
inputs
[
1
:]
return
[
advanced_inc_subtensor
(
gx
,
gz
,
*
rest
)]
+
[
DisconnectedType
()()]
*
len
(
return
[
rest
advanced_inc_subtensor
(
gx
,
gz
,
*
rest
),
)
*
(
disconnected_type
()
for
_
in
range
(
len
(
rest
))),
]
@staticmethod
@staticmethod
def
non_contiguous_adv_indexing
(
node
:
Apply
)
->
bool
:
def
non_contiguous_adv_indexing
(
node
:
Apply
)
->
bool
:
...
@@ -2933,7 +2934,7 @@ class AdvancedIncSubtensor(Op):
...
@@ -2933,7 +2934,7 @@ class AdvancedIncSubtensor(Op):
# Make sure to sum gy over the dimensions of y that have been
# Make sure to sum gy over the dimensions of y that have been
# added or broadcasted
# added or broadcasted
gy
=
_sum_grad_over_bcasted_dims
(
y
,
gy
)
gy
=
_sum_grad_over_bcasted_dims
(
y
,
gy
)
return
[
gx
,
gy
]
+
[
DisconnectedType
()()
for
_
in
idxs
]
return
[
gx
,
gy
,
*
(
disconnected_type
()
for
_
in
range
(
len
(
idxs
)))
]
@staticmethod
@staticmethod
def
non_contiguous_adv_indexing
(
node
:
Apply
)
->
bool
:
def
non_contiguous_adv_indexing
(
node
:
Apply
)
->
bool
:
...
...
pytensor/tensor/type_other.py
浏览文件 @
2faa56a4
...
@@ -6,7 +6,7 @@ import numpy as np
...
@@ -6,7 +6,7 @@ import numpy as np
import
pytensor
import
pytensor
from
pytensor
import
_as_symbolic
from
pytensor
import
_as_symbolic
from
pytensor.gradient
import
DisconnectedT
ype
from
pytensor.gradient
import
disconnected_t
ype
from
pytensor.graph.basic
import
Apply
,
Constant
,
Variable
from
pytensor.graph.basic
import
Apply
,
Constant
,
Variable
from
pytensor.graph.op
import
Op
from
pytensor.graph.op
import
Op
from
pytensor.link.c.type
import
Generic
,
Type
from
pytensor.link.c.type
import
Generic
,
Type
...
@@ -44,7 +44,7 @@ class MakeSlice(Op):
...
@@ -44,7 +44,7 @@ class MakeSlice(Op):
out
[
0
]
=
slice
(
*
inp
)
out
[
0
]
=
slice
(
*
inp
)
def
grad
(
self
,
inputs
,
grads
):
def
grad
(
self
,
inputs
,
grads
):
return
[
DisconnectedType
()()
for
i
in
inputs
]
return
[
disconnected_type
()
for
_
in
range
(
len
(
inputs
))
]
make_slice
=
MakeSlice
()
make_slice
=
MakeSlice
()
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论