Commit 2f4cbc39 authored by Ricardo Vieira, committed by Ricardo Vieira

Remove unused Composite casting logic

Parent 6bf3878d
......@@ -4300,14 +4300,8 @@ class Composite(ScalarInnerGraphOp):
self._fgraph = fgraph
return self._fgraph
def clone_float32(self):
# This will not modify the fgraph or the nodes
new_ins, new_outs = composite_f32.apply(self.fgraph)
return Composite(new_ins, new_outs)
def clone(self):
new_ins, new_outs = composite_f32.apply(self.fgraph)
return Composite(new_ins, new_outs)
return Composite(self.fgraph.inputs, self.fgraph.outputs)
def output_types(self, input_types):
if tuple(input_types) != self.inputs_type:
......@@ -4423,86 +4417,4 @@ class Composite(ScalarInnerGraphOp):
return self.c_code_template % d
def c_code_cache_version_outer(self) -> tuple[int, ...]:
return (6,)
class Compositef32:
# This is a dict of scalar op classes that need special handling
special: dict = {}
def apply(self, fgraph):
mapping = {}
topo = fgraph.toposort()
for i in fgraph.inputs:
if i.dtype == "float16":
mapping[i] = get_scalar_type("float32")()
if hasattr(i.tag, "test_value"):
mapping[i].tag.test_value = i.tag.test_value
else:
mapping[i] = i
for node in topo:
# Patch up for constants
for i in node.inputs:
if i not in mapping:
assert type(i) is ScalarConstant
if i.type == float16:
ni = ScalarConstant(float32, i.data)
else:
ni = i
mapping[i] = ni
if isinstance(node.op, tuple(self.special)):
self.special[type(node.op)](node, mapping)
continue
new_node = node.clone_with_new_inputs(
[mapping[inp] for inp in node.inputs], strict=False
)
# make sure we don't produce any float16.
assert not any(o.dtype == "float16" for o in new_node.outputs)
mapping.update(zip(node.outputs, new_node.outputs, strict=True))
new_ins = [mapping[inp] for inp in fgraph.inputs]
new_outs = [mapping[out] for out in fgraph.outputs]
return new_ins, new_outs
composite_f32 = Compositef32()
def handle_cast(node, mapping):
inp = mapping[node.inputs[0]]
out = node.outputs[0]
node_ok = False
if node.op.o_type == float16:
if node.inputs[0].type == float32:
# cast f32 -> f16, remove
mapping[out] = inp
return
else:
# cast to f16, convert to f32
new_out = cast(inp, "float32")
# change the node for the following if
node = new_out.owner
mapping[out] = new_out
node_ok = True
if node.inputs[0].type == float16:
if node.op.o_type == inp.type:
# cast f16 to new input type, remove
mapping[out] = inp
return
if not node_ok:
new_node = node.clone_with_new_inputs([inp], strict=False)
mapping[out] = new_node.outputs[0]
Compositef32.special[Cast] = handle_cast
def handle_composite(node, mapping):
new_op = node.op.clone_float32()
new_outs = new_op(*[mapping[i] for i in node.inputs], return_list=True)
assert len(new_outs) == len(node.outputs)
for o, no in zip(node.outputs, new_outs, strict=True):
mapping[o] = no
Compositef32.special[Composite] = handle_composite
return (7,)
......@@ -23,7 +23,6 @@ from pytensor.scalar.basic import (
arctan,
arctan2,
arctanh,
cast,
complex64,
constant,
cos,
......@@ -33,7 +32,6 @@ from pytensor.scalar.basic import (
exp,
exp2,
expm1,
float16,
float32,
floats,
int8,
......@@ -53,11 +51,9 @@ from pytensor.scalar.basic import (
sin,
sinh,
sqrt,
switch,
tan,
tanh,
true_div,
uint8,
)
from pytensor.tensor.type import fscalar, imatrix, matrix
from tests.link.test_link import make_function
......@@ -72,43 +68,6 @@ def test_mul_add_true():
class TestComposite:
def test_composite_clone_float32(self):
def has_f16(comp):
if any(v.type == float16 for v in comp.fgraph.variables):
return True
return False
w = int8()
x = float16()
y = float32()
cz = Composite([x, y], [tanh(x + cast(y, "float16"))])
c = Composite(
[w, x, y],
[
cz(x, y)
- cz(x, y) ** 2
+ cast(x, "int16")
+ cast(x, "float32")
+ cast(w, "float16")
- constant(np.float16(1.0))
],
)
assert has_f16(c)
nc = c.clone_float32()
assert not has_f16(nc)
v = uint8()
w = float16()
x = float16()
y = float16()
z = float16()
c = Composite([v, w, x, y, z], [switch(v, mul(w, x, y), z)])
assert has_f16(c)
nc = c.clone_float32()
assert not has_f16(nc)
def test_straightforward(self):
x, y, _z = floats("xyz")
e = mul(add(x, y), true_div(x, y))
......
Markdown format
0%
You are adding 0 people to this discussion. Please proceed carefully.
Please finish editing this comment first!
Register or sign in to comment