提交 269903aa authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Use TensorType.ndim instead of TensorVariable.ndim in shape operations

This change makes the respective operations work with non-`TensorVariable` classes.
上级 7bc40c67
...@@ -39,8 +39,8 @@ def infer_shape(outs, inputs, input_shapes): ...@@ -39,8 +39,8 @@ def infer_shape(outs, inputs, input_shapes):
# let it initialize itself with an empty fgraph, otherwise we will # let it initialize itself with an empty fgraph, otherwise we will
# need to do it manually # need to do it manually
for inp, inp_shp in zip(inputs, input_shapes): for inp, inp_shp in zip(inputs, input_shapes):
if inp_shp is not None and len(inp_shp) != inp.ndim: if inp_shp is not None and len(inp_shp) != inp.type.ndim:
assert len(inp_shp) == inp.ndim assert len(inp_shp) == inp.type.ndim
shape_feature = ShapeFeature() shape_feature = ShapeFeature()
shape_feature.on_attach(FunctionGraph([], [])) shape_feature.on_attach(FunctionGraph([], []))
......
...@@ -909,7 +909,7 @@ class ShapeFeature(features.Feature): ...@@ -909,7 +909,7 @@ class ShapeFeature(features.Feature):
node = var.owner node = var.owner
# recur on inputs # recur on inputs
for i in node.inputs: for i in node.inputs:
if getattr(i, "ndim", None) > 0: if getattr(i.type, "ndim", None) > 0:
self.get_shape(i, 0) self.get_shape(i, 0)
o_shapes = self.get_node_infer_shape(node) o_shapes = self.get_node_infer_shape(node)
assert len(o_shapes) == len(node.outputs) assert len(o_shapes) == len(node.outputs)
...@@ -917,12 +917,12 @@ class ShapeFeature(features.Feature): ...@@ -917,12 +917,12 @@ class ShapeFeature(features.Feature):
# Only change the variables and dimensions that would introduce # Only change the variables and dimensions that would introduce
# extra computation # extra computation
for new_shps, out in zip(o_shapes, node.outputs): for new_shps, out in zip(o_shapes, node.outputs):
if not hasattr(out, "ndim"): if not hasattr(out.type, "ndim"):
continue continue
merged_shps = list(self.shape_of[out]) merged_shps = list(self.shape_of[out])
changed = False changed = False
for i in range(out.ndim): for i in range(out.type.ndim):
n_r = merged_shps[i] n_r = merged_shps[i]
if ( if (
n_r.owner n_r.owner
...@@ -951,10 +951,10 @@ class ShapeFeature(features.Feature): ...@@ -951,10 +951,10 @@ class ShapeFeature(features.Feature):
def shape_tuple(self, r): def shape_tuple(self, r):
"""Return a tuple of symbolic shape vars for tensor variable r.""" """Return a tuple of symbolic shape vars for tensor variable r."""
if not hasattr(r, "ndim"): if not hasattr(r.type, "ndim"):
# This happen for NoneConst. # This happen for NoneConst.
return None return None
return tuple([self.shape_ir(i, r) for i in range(r.ndim)]) return tuple(self.shape_ir(i, r) for i in range(r.type.ndim))
def default_infer_shape(self, fgraph, node, i_shapes): def default_infer_shape(self, fgraph, node, i_shapes):
"""Return a list of shape tuple or None for the outputs of node. """Return a list of shape tuple or None for the outputs of node.
...@@ -1020,7 +1020,7 @@ class ShapeFeature(features.Feature): ...@@ -1020,7 +1020,7 @@ class ShapeFeature(features.Feature):
and s_i.owner.inputs[0].owner and s_i.owner.inputs[0].owner
and isinstance(s_i.owner.inputs[0].owner.op, Shape) and isinstance(s_i.owner.inputs[0].owner.op, Shape)
): ):
assert s_i.ndim == 0 assert s_i.type.ndim == 0
assert len(s_i.owner.op.idx_list) == 1 assert len(s_i.owner.op.idx_list) == 1
# The current Subtensor always put constant index in the graph. # The current Subtensor always put constant index in the graph.
...@@ -1068,32 +1068,28 @@ class ShapeFeature(features.Feature): ...@@ -1068,32 +1068,28 @@ class ShapeFeature(features.Feature):
if not isinstance(s, (tuple, list)): if not isinstance(s, (tuple, list)):
raise TypeError("shapes must be tuple/list", (r, s)) raise TypeError("shapes must be tuple/list", (r, s))
if r.ndim != len(s): if r.type.ndim != len(s):
sio = StringIO() sio = StringIO()
aesara.printing.debugprint(r, file=sio, print_type=True) aesara.printing.debugprint(r, file=sio, print_type=True)
raise AssertionError( raise AssertionError(
f"Something inferred a shape with {len(s)} dimensions " f"Something inferred a shape with {len(s)} dimensions "
f"for a variable with {int(r.ndim)} dimensions" f"for a variable with {int(r.type.ndim)} dimensions"
f" for the variable:\n{sio.getvalue()}" f" for the variable:\n{sio.getvalue()}"
) )
shape_vars = [] shape_vars = []
for i in range(r.ndim): for i in range(r.type.ndim):
if hasattr(r.type, "broadcastable") and r.type.broadcastable[i]: if hasattr(r.type, "broadcastable") and r.type.broadcastable[i]:
shape_vars.append(self.lscalar_one) shape_vars.append(self.lscalar_one)
else: else:
shape_vars.append(self.unpack(s[i], r)) shape_vars.append(self.unpack(s[i], r))
assert all( assert all(
[ not hasattr(r.type, "broadcastable") or not r.type.broadcastable[i] or
not hasattr(r.type, "broadcastable")
or not r.type.broadcastable[i]
or
# The two following comparison are a speed optimization # The two following comparison are a speed optimization
# But we never timed this speed optimization! # But we never timed this speed optimization!
self.lscalar_one.equals(shape_vars[i]) self.lscalar_one.equals(shape_vars[i])
or self.lscalar_one.equals(extract_constant(shape_vars[i])) or self.lscalar_one.equals(extract_constant(shape_vars[i]))
for i in range(r.ndim) for i in range(r.type.ndim)
]
) )
self.shape_of[r] = tuple(shape_vars) self.shape_of[r] = tuple(shape_vars)
for sv in shape_vars: for sv in shape_vars:
...@@ -1171,7 +1167,6 @@ class ShapeFeature(features.Feature): ...@@ -1171,7 +1167,6 @@ class ShapeFeature(features.Feature):
else: else:
merged_shape.append(other_shape[i]) merged_shape.append(other_shape[i])
assert all( assert all(
[
( (
not hasattr(r.type, "broadcastable") not hasattr(r.type, "broadcastable")
or not r.type.broadcastable[i] or not r.type.broadcastable[i]
...@@ -1184,8 +1179,7 @@ class ShapeFeature(features.Feature): ...@@ -1184,8 +1179,7 @@ class ShapeFeature(features.Feature):
or self.lscalar_one.equals( or self.lscalar_one.equals(
extract_constant(merged_shape[i], only_process_constants=True) extract_constant(merged_shape[i], only_process_constants=True)
) )
for i in range(r.ndim) for i in range(r.type.ndim)
]
) )
self.shape_of[r] = tuple(merged_shape) self.shape_of[r] = tuple(merged_shape)
for sv in self.shape_of[r]: for sv in self.shape_of[r]:
...@@ -1204,14 +1198,12 @@ class ShapeFeature(features.Feature): ...@@ -1204,14 +1198,12 @@ class ShapeFeature(features.Feature):
else: else:
new_shape.append(s_j) new_shape.append(s_j)
assert all( assert all(
[
not hasattr(r.type, "broadcastable") or not r.type.broadcastable[idx] or not hasattr(r.type, "broadcastable") or not r.type.broadcastable[idx] or
# The two following comparison are a speed optimization # The two following comparison are a speed optimization
# But we never timed this speed optimization! # But we never timed this speed optimization!
self.lscalar_one.equals(new_shape[idx]) self.lscalar_one.equals(new_shape[idx])
or self.lscalar_one.equals(extract_constant(new_shape[idx])) or self.lscalar_one.equals(extract_constant(new_shape[idx]))
for idx in range(r.ndim) for idx in range(r.type.ndim)
]
) )
self.shape_of[r] = tuple(new_shape) self.shape_of[r] = tuple(new_shape)
for sv in self.shape_of[r]: for sv in self.shape_of[r]:
......
...@@ -63,7 +63,7 @@ class Shape(COp): ...@@ -63,7 +63,7 @@ class Shape(COp):
x = at.as_tensor_variable(x) x = at.as_tensor_variable(x)
if isinstance(x.type, TensorType): if isinstance(x.type, TensorType):
out_var = TensorType("int64", (x.ndim,))() out_var = TensorType("int64", (x.type.ndim,))()
else: else:
out_var = aesara.tensor.type.lvector() out_var = aesara.tensor.type.lvector()
...@@ -164,7 +164,9 @@ def shape_tuple(x: Variable) -> Tuple[Variable]: ...@@ -164,7 +164,9 @@ def shape_tuple(x: Variable) -> Tuple[Variable]:
one_at = aesara.scalar.ScalarConstant(aesara.scalar.int64, 1) one_at = aesara.scalar.ScalarConstant(aesara.scalar.int64, 1)
return tuple( return tuple(
one_at if getattr(sh, "value", sh) == 1 or bcast else sh one_at if getattr(sh, "value", sh) == 1 or bcast else sh
for sh, bcast in zip(shape(x), getattr(x, "broadcastable", (False,) * x.ndim)) for sh, bcast in zip(
shape(x), getattr(x, "broadcastable", (False,) * x.type.ndim)
)
) )
...@@ -214,9 +216,11 @@ class Shape_i(COp): ...@@ -214,9 +216,11 @@ class Shape_i(COp):
return "%s{%i}" % (self.__class__.__name__, self.i) return "%s{%i}" % (self.__class__.__name__, self.i)
def make_node(self, x): def make_node(self, x):
if not isinstance(x, Variable): if not isinstance(x, Variable) or not hasattr(x.type, "ndim"):
raise TypeError(f"{x} must be Variable with ndim attribute") raise TypeError(
if x.ndim <= self.i: f"{x} must be `Variable` with a `Type` having an ndim attribute"
)
if x.type.ndim <= self.i:
raise TypeError(f"{x} has too few dimensions for Shape_i") raise TypeError(f"{x} has too few dimensions for Shape_i")
return Apply(self, [x], [aesara.tensor.type.lscalar()]) return Apply(self, [x], [aesara.tensor.type.lscalar()])
...@@ -421,9 +425,9 @@ class SpecifyShape(COp): ...@@ -421,9 +425,9 @@ class SpecifyShape(COp):
if any(s.dtype not in aesara.tensor.type.integer_dtypes for s in shape): if any(s.dtype not in aesara.tensor.type.integer_dtypes for s in shape):
raise TypeError("Shape values must be integer types") raise TypeError("Shape values must be integer types")
if len(shape) != x.ndim: if len(shape) != x.type.ndim:
raise ValueError( raise ValueError(
f"Input `x` is {x.ndim}-dimensional and will never match a shape of length {len(shape)}." f"Input `x` is {x.type.ndim}-dimensional and will never match a shape of length {len(shape)}."
) )
if isinstance(x.type, TensorType) and all(isinstance(s, Number) for s in shape): if isinstance(x.type, TensorType) and all(isinstance(s, Number) for s in shape):
...@@ -451,7 +455,7 @@ class SpecifyShape(COp): ...@@ -451,7 +455,7 @@ class SpecifyShape(COp):
def infer_shape(self, fgraph, node, shapes): def infer_shape(self, fgraph, node, shapes):
xshape, sshape = shapes xshape, sshape = shapes
new_shape = [] new_shape = []
for dim in range(node.inputs[0].ndim): for dim in range(node.inputs[0].type.ndim):
try: try:
s = at.get_scalar_constant_value(node.inputs[1][dim]) s = at.get_scalar_constant_value(node.inputs[1][dim])
s = at.as_tensor_variable(s) s = at.as_tensor_variable(s)
......
...@@ -2987,6 +2987,8 @@ def test_local_Shape_i_of_broadcastable(): ...@@ -2987,6 +2987,8 @@ def test_local_Shape_i_of_broadcastable():
# A test for a non-`TensorType` # A test for a non-`TensorType`
class MyType(Type): class MyType(Type):
ndim = 1
def filter(self, *args, **kwargs): def filter(self, *args, **kwargs):
raise NotImplementedError() raise NotImplementedError()
...@@ -2994,7 +2996,7 @@ def test_local_Shape_i_of_broadcastable(): ...@@ -2994,7 +2996,7 @@ def test_local_Shape_i_of_broadcastable():
return isinstance(other, MyType) and other.thingy == self.thingy return isinstance(other, MyType) and other.thingy == self.thingy
class MyVariable(Variable): class MyVariable(Variable):
ndim = 1 pass
x = MyVariable(MyType(), None, None) x = MyVariable(MyType(), None, None)
s = Shape_i(0)(x) s = Shape_i(0)(x)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论