提交 0bb33d65 authored 作者: Tanjay94's avatar Tanjay94

Fixed infer_shape.

上级 95701b81
......@@ -471,29 +471,16 @@ Op Example
from theano.compile.ops import as_op
from theano.compile.ops import FromFunctionOp
def infer_shape_numpy_dot(node, input_shapes):
ashp, bshp = input_shapes
return [ashp[:-1] + bshp[-1:]]
@as_op(itypes=[theano.tensor.fmatrix, theano.tensor.fmatrix],
otypes=[theano.tensor.fmatrix])
otypes=[theano.tensor.fmatrix], infer_shape=infer_shape_numpy_dot)
def numpy_dot(a, b):
return numpy.dot(a, b)
def infer_shape(self, node, shapes):
xshp, yshp = shapes
x, y = node.inputs
# vector / vector
if x.ndim == 1 and y.ndim == 1:
return [()]
# matrix / vector
if x.ndim == 2 and y.ndim == 1:
return [xshp[:-1]]
# vector / matrix
if x.ndim == 1 and y.ndim == 2:
return [yshp[-1:]]
# matrix / matrix
if x.ndim == 2 and y.ndim == 2:
return [xshp[:-1] + yshp[-1:]]
raise NotImplementedError()
You can try it as follows:
.. code-block:: python
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论