Unverified 提交 46fdc58e authored 作者: Ian Schweer's avatar Ian Schweer 提交者: GitHub

Add torch implementation of IfElse (#974)

上级 8a6e407e
......@@ -9,6 +9,7 @@ from pytensor.compile import PYTORCH
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.ops import DeepCopyOp
from pytensor.graph.fg import FunctionGraph
from pytensor.ifelse import IfElse
from pytensor.link.utils import fgraph_to_python
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor.basic import (
......@@ -153,6 +154,19 @@ def pytorch_funcify_MakeVector(op, **kwargs):
return makevector
@pytorch_funcify.register(IfElse)
def pytorch_funcify_IfElse(op, **kwargs):
n_outs = op.n_outs
def ifelse(cond, *true_and_false, n_outs=n_outs):
if cond:
return true_and_false[:n_outs]
else:
return true_and_false[n_outs:]
return ifelse
@pytorch_funcify.register(OpFromGraph)
def pytorch_funcify_OpFromGraph(op, node, **kwargs):
kwargs.pop("storage_map", None)
......
......@@ -13,6 +13,7 @@ from pytensor.configdefaults import config
from pytensor.graph.basic import Apply
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op
from pytensor.ifelse import ifelse
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor import alloc, arange, as_tensor, empty, eye
from pytensor.tensor.type import matrices, matrix, scalar, vector
......@@ -304,6 +305,23 @@ def test_pytorch_MakeVector():
compare_pytorch_and_py(x_fg, [])
def test_pytorch_ifelse():
p1_vals = np.r_[1, 2, 3]
p2_vals = np.r_[-1, -2, -3]
a = scalar("a")
x = ifelse(a < 0.5, tuple(np.r_[p1_vals, p2_vals]), tuple(np.r_[p2_vals, p1_vals]))
x_fg = FunctionGraph([a], x)
compare_pytorch_and_py(x_fg, np.array([0.2], dtype=config.floatX))
a = scalar("a")
x = ifelse(a < 0.4, tuple(np.r_[p1_vals, p2_vals]), tuple(np.r_[p2_vals, p1_vals]))
x_fg = FunctionGraph([a], x)
compare_pytorch_and_py(x_fg, np.array([0.5], dtype=config.floatX))
def test_pytorch_OpFromGraph():
x, y, z = matrices("xyz")
ofg_1 = OpFromGraph([x, y], [x + y])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论