Unverified 提交 46fdc58e authored 作者: Ian Schweer's avatar Ian Schweer 提交者: GitHub

Add torch implementation of IfElse (#974)

上级 8a6e407e
...@@ -9,6 +9,7 @@ from pytensor.compile import PYTORCH ...@@ -9,6 +9,7 @@ from pytensor.compile import PYTORCH
from pytensor.compile.builders import OpFromGraph from pytensor.compile.builders import OpFromGraph
from pytensor.compile.ops import DeepCopyOp from pytensor.compile.ops import DeepCopyOp
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.ifelse import IfElse
from pytensor.link.utils import fgraph_to_python from pytensor.link.utils import fgraph_to_python
from pytensor.raise_op import CheckAndRaise from pytensor.raise_op import CheckAndRaise
from pytensor.tensor.basic import ( from pytensor.tensor.basic import (
...@@ -153,6 +154,19 @@ def pytorch_funcify_MakeVector(op, **kwargs): ...@@ -153,6 +154,19 @@ def pytorch_funcify_MakeVector(op, **kwargs):
return makevector return makevector
@pytorch_funcify.register(IfElse)
def pytorch_funcify_IfElse(op, **kwargs):
n_outs = op.n_outs
def ifelse(cond, *true_and_false, n_outs=n_outs):
if cond:
return true_and_false[:n_outs]
else:
return true_and_false[n_outs:]
return ifelse
@pytorch_funcify.register(OpFromGraph) @pytorch_funcify.register(OpFromGraph)
def pytorch_funcify_OpFromGraph(op, node, **kwargs): def pytorch_funcify_OpFromGraph(op, node, **kwargs):
kwargs.pop("storage_map", None) kwargs.pop("storage_map", None)
......
...@@ -13,6 +13,7 @@ from pytensor.configdefaults import config ...@@ -13,6 +13,7 @@ from pytensor.configdefaults import config
from pytensor.graph.basic import Apply from pytensor.graph.basic import Apply
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.ifelse import ifelse
from pytensor.raise_op import CheckAndRaise from pytensor.raise_op import CheckAndRaise
from pytensor.tensor import alloc, arange, as_tensor, empty, eye from pytensor.tensor import alloc, arange, as_tensor, empty, eye
from pytensor.tensor.type import matrices, matrix, scalar, vector from pytensor.tensor.type import matrices, matrix, scalar, vector
...@@ -304,6 +305,23 @@ def test_pytorch_MakeVector(): ...@@ -304,6 +305,23 @@ def test_pytorch_MakeVector():
compare_pytorch_and_py(x_fg, []) compare_pytorch_and_py(x_fg, [])
def test_pytorch_ifelse():
p1_vals = np.r_[1, 2, 3]
p2_vals = np.r_[-1, -2, -3]
a = scalar("a")
x = ifelse(a < 0.5, tuple(np.r_[p1_vals, p2_vals]), tuple(np.r_[p2_vals, p1_vals]))
x_fg = FunctionGraph([a], x)
compare_pytorch_and_py(x_fg, np.array([0.2], dtype=config.floatX))
a = scalar("a")
x = ifelse(a < 0.4, tuple(np.r_[p1_vals, p2_vals]), tuple(np.r_[p2_vals, p1_vals]))
x_fg = FunctionGraph([a], x)
compare_pytorch_and_py(x_fg, np.array([0.5], dtype=config.floatX))
def test_pytorch_OpFromGraph(): def test_pytorch_OpFromGraph():
x, y, z = matrices("xyz") x, y, z = matrices("xyz")
ofg_1 = OpFromGraph([x, y], [x + y]) ofg_1 = OpFromGraph([x, y], [x + y])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论