提交 c6fc7c59 authored 作者: James Bergstra's avatar James Bergstra

Added AdvancedSubtensor.infer_shape

上级 dcbc3e43
......@@ -3296,6 +3296,9 @@ class AdvancedSubtensor(Op):
% (x.ndim, ','.join(str(input.ndim) for input in inputs)))
raise NotImplementedError('Advanced indexing of x with arguments (%s) not supported yet'\
% ','.join(str(input) for input in inputs))
def infer_shape(self, node, ishapes):
xshp, ind1shp, ind2shp = ishapes
return [ind2shp]
def perform(self, node, inputs, (out,)):
# TODO: in general, we need to re-pack the inputs into a valid index, just like
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论