提交 269903aa authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Use TensorType.ndim instead of TensorVariable.ndim in shape operations

This change makes the respective operations work with non-`TensorVariable` classes.
上级 7bc40c67
......@@ -39,8 +39,8 @@ def infer_shape(outs, inputs, input_shapes):
# let it initialize itself with an empty fgraph, otherwise we will
# need to do it manually
for inp, inp_shp in zip(inputs, input_shapes):
if inp_shp is not None and len(inp_shp) != inp.ndim:
assert len(inp_shp) == inp.ndim
if inp_shp is not None and len(inp_shp) != inp.type.ndim:
assert len(inp_shp) == inp.type.ndim
shape_feature = ShapeFeature()
shape_feature.on_attach(FunctionGraph([], []))
......
......@@ -909,7 +909,7 @@ class ShapeFeature(features.Feature):
node = var.owner
# recur on inputs
for i in node.inputs:
if getattr(i, "ndim", None) > 0:
if getattr(i.type, "ndim", None) > 0:
self.get_shape(i, 0)
o_shapes = self.get_node_infer_shape(node)
assert len(o_shapes) == len(node.outputs)
......@@ -917,12 +917,12 @@ class ShapeFeature(features.Feature):
# Only change the variables and dimensions that would introduce
# extra computation
for new_shps, out in zip(o_shapes, node.outputs):
if not hasattr(out, "ndim"):
if not hasattr(out.type, "ndim"):
continue
merged_shps = list(self.shape_of[out])
changed = False
for i in range(out.ndim):
for i in range(out.type.ndim):
n_r = merged_shps[i]
if (
n_r.owner
......@@ -951,10 +951,10 @@ class ShapeFeature(features.Feature):
def shape_tuple(self, r):
    """Return a tuple of symbolic shape variables for the variable `r`.

    Uses ``r.type.ndim`` rather than ``r.ndim`` so that any `Variable`
    whose `Type` carries an ``ndim`` attribute is supported, not only
    `TensorVariable`s.

    Returns
    -------
    tuple or None
        One symbolic shape variable per dimension (via ``self.shape_ir``),
        or ``None`` when ``r.type`` has no ``ndim`` at all.
    """
    if not hasattr(r.type, "ndim"):
        # This happens e.g. for NoneConst, whose type has no dimensions.
        return None
    return tuple(self.shape_ir(i, r) for i in range(r.type.ndim))
def default_infer_shape(self, fgraph, node, i_shapes):
"""Return a list of shape tuple or None for the outputs of node.
......@@ -1020,7 +1020,7 @@ class ShapeFeature(features.Feature):
and s_i.owner.inputs[0].owner
and isinstance(s_i.owner.inputs[0].owner.op, Shape)
):
assert s_i.ndim == 0
assert s_i.type.ndim == 0
assert len(s_i.owner.op.idx_list) == 1
# The current Subtensor always put constant index in the graph.
......@@ -1068,32 +1068,28 @@ class ShapeFeature(features.Feature):
if not isinstance(s, (tuple, list)):
raise TypeError("shapes must be tuple/list", (r, s))
if r.ndim != len(s):
if r.type.ndim != len(s):
sio = StringIO()
aesara.printing.debugprint(r, file=sio, print_type=True)
raise AssertionError(
f"Something inferred a shape with {len(s)} dimensions "
f"for a variable with {int(r.ndim)} dimensions"
f"for a variable with {int(r.type.ndim)} dimensions"
f" for the variable:\n{sio.getvalue()}"
)
shape_vars = []
for i in range(r.ndim):
for i in range(r.type.ndim):
if hasattr(r.type, "broadcastable") and r.type.broadcastable[i]:
shape_vars.append(self.lscalar_one)
else:
shape_vars.append(self.unpack(s[i], r))
assert all(
[
not hasattr(r.type, "broadcastable")
or not r.type.broadcastable[i]
or
# The two following comparison are a speed optimization
# But we never timed this speed optimization!
self.lscalar_one.equals(shape_vars[i])
or self.lscalar_one.equals(extract_constant(shape_vars[i]))
for i in range(r.ndim)
]
not hasattr(r.type, "broadcastable") or not r.type.broadcastable[i] or
# The two following comparison are a speed optimization
# But we never timed this speed optimization!
self.lscalar_one.equals(shape_vars[i])
or self.lscalar_one.equals(extract_constant(shape_vars[i]))
for i in range(r.type.ndim)
)
self.shape_of[r] = tuple(shape_vars)
for sv in shape_vars:
......@@ -1171,21 +1167,19 @@ class ShapeFeature(features.Feature):
else:
merged_shape.append(other_shape[i])
assert all(
[
(
not hasattr(r.type, "broadcastable")
or not r.type.broadcastable[i]
and not other_r.type.broadcastable[i]
)
or
# The two following comparison are a speed optimization
# But we never timed this speed optimization!
self.lscalar_one.equals(merged_shape[i])
or self.lscalar_one.equals(
extract_constant(merged_shape[i], only_process_constants=True)
)
for i in range(r.ndim)
]
(
not hasattr(r.type, "broadcastable")
or not r.type.broadcastable[i]
and not other_r.type.broadcastable[i]
)
or
# The two following comparison are a speed optimization
# But we never timed this speed optimization!
self.lscalar_one.equals(merged_shape[i])
or self.lscalar_one.equals(
extract_constant(merged_shape[i], only_process_constants=True)
)
for i in range(r.type.ndim)
)
self.shape_of[r] = tuple(merged_shape)
for sv in self.shape_of[r]:
......@@ -1204,14 +1198,12 @@ class ShapeFeature(features.Feature):
else:
new_shape.append(s_j)
assert all(
[
not hasattr(r.type, "broadcastable") or not r.type.broadcastable[idx] or
# The two following comparison are a speed optimization
# But we never timed this speed optimization!
self.lscalar_one.equals(new_shape[idx])
or self.lscalar_one.equals(extract_constant(new_shape[idx]))
for idx in range(r.ndim)
]
not hasattr(r.type, "broadcastable") or not r.type.broadcastable[idx] or
# The two following comparison are a speed optimization
# But we never timed this speed optimization!
self.lscalar_one.equals(new_shape[idx])
or self.lscalar_one.equals(extract_constant(new_shape[idx]))
for idx in range(r.type.ndim)
)
self.shape_of[r] = tuple(new_shape)
for sv in self.shape_of[r]:
......
......@@ -63,7 +63,7 @@ class Shape(COp):
x = at.as_tensor_variable(x)
if isinstance(x.type, TensorType):
out_var = TensorType("int64", (x.ndim,))()
out_var = TensorType("int64", (x.type.ndim,))()
else:
out_var = aesara.tensor.type.lvector()
......@@ -164,7 +164,9 @@ def shape_tuple(x: Variable) -> Tuple[Variable]:
one_at = aesara.scalar.ScalarConstant(aesara.scalar.int64, 1)
return tuple(
one_at if getattr(sh, "value", sh) == 1 or bcast else sh
for sh, bcast in zip(shape(x), getattr(x, "broadcastable", (False,) * x.ndim))
for sh, bcast in zip(
shape(x), getattr(x, "broadcastable", (False,) * x.type.ndim)
)
)
......@@ -214,9 +216,11 @@ class Shape_i(COp):
return "%s{%i}" % (self.__class__.__name__, self.i)
def make_node(self, x):
    """Build an `Apply` node that extracts dimension ``self.i`` of ``x``'s shape.

    Checks ``x.type.ndim`` rather than ``x.ndim`` so that any `Variable`
    whose `Type` exposes an ``ndim`` attribute is accepted, not only
    `TensorVariable`s.

    Raises
    ------
    TypeError
        If `x` is not a `Variable` whose `Type` has an ``ndim`` attribute,
        or if ``x`` has ``self.i`` or fewer dimensions.
    """
    if not isinstance(x, Variable) or not hasattr(x.type, "ndim"):
        raise TypeError(
            f"{x} must be `Variable` with a `Type` having an ndim attribute"
        )
    if x.type.ndim <= self.i:
        raise TypeError(f"{x} has too few dimensions for Shape_i")
    # The result of a single shape dimension is always a 0-d int64 scalar.
    return Apply(self, [x], [aesara.tensor.type.lscalar()])
......@@ -421,9 +425,9 @@ class SpecifyShape(COp):
if any(s.dtype not in aesara.tensor.type.integer_dtypes for s in shape):
raise TypeError("Shape values must be integer types")
if len(shape) != x.ndim:
if len(shape) != x.type.ndim:
raise ValueError(
f"Input `x` is {x.ndim}-dimensional and will never match a shape of length {len(shape)}."
f"Input `x` is {x.type.ndim}-dimensional and will never match a shape of length {len(shape)}."
)
if isinstance(x.type, TensorType) and all(isinstance(s, Number) for s in shape):
......@@ -451,7 +455,7 @@ class SpecifyShape(COp):
def infer_shape(self, fgraph, node, shapes):
xshape, sshape = shapes
new_shape = []
for dim in range(node.inputs[0].ndim):
for dim in range(node.inputs[0].type.ndim):
try:
s = at.get_scalar_constant_value(node.inputs[1][dim])
s = at.as_tensor_variable(s)
......
......@@ -2987,6 +2987,8 @@ def test_local_Shape_i_of_broadcastable():
# A test for a non-`TensorType`
class MyType(Type):
ndim = 1
def filter(self, *args, **kwargs):
raise NotImplementedError()
......@@ -2994,7 +2996,7 @@ def test_local_Shape_i_of_broadcastable():
return isinstance(other, MyType) and other.thingy == self.thingy
class MyVariable(Variable):
ndim = 1
pass
x = MyVariable(MyType(), None, None)
s = Shape_i(0)(x)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论