提交 a1d277bc authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Allow NumPy-only arguments to theano.tensor.dot

上级 6d39faaf
...@@ -71,6 +71,7 @@ from tests.tensor.utils import ( ...@@ -71,6 +71,7 @@ from tests.tensor.utils import (
from theano import change_flags, compile, config, function, gof, shared from theano import change_flags, compile, config, function, gof, shared
from theano.compile import DeepCopyOp from theano.compile import DeepCopyOp
from theano.compile.mode import get_default_mode from theano.compile.mode import get_default_mode
from theano.gof.graph import Variable
from theano.scalar import autocast_float, autocast_float_as from theano.scalar import autocast_float, autocast_float_as
from theano.tensor import ( from theano.tensor import (
Alloc, Alloc,
...@@ -3648,6 +3649,15 @@ class TestMatinv: ...@@ -3648,6 +3649,15 @@ class TestMatinv:
assert_almost_equal(ssd, myssd) assert_almost_equal(ssd, myssd)
def test_dot_numpy_inputs():
"""Test the `theano.tensor.dot` interface function with NumPy inputs."""
a = np.ones(2)
b = np.ones(2)
res = tt.dot(a, b)
assert isinstance(res, Variable)
assert isinstance(res.owner.op, Dot)
class TestDot: class TestDot:
def setup_method(self): def setup_method(self):
utt.seed_rng() utt.seed_rng()
......
...@@ -6311,6 +6311,13 @@ def dot(l, r): ...@@ -6311,6 +6311,13 @@ def dot(l, r):
This is designed to work with both sparse and dense tensors types. This is designed to work with both sparse and dense tensors types.
""" """
if not isinstance(l, Variable):
l = as_tensor_variable(l)
if not isinstance(r, Variable):
r = as_tensor_variable(r)
try: try:
res = l.__dot__(r) res = l.__dot__(r)
if res is NotImplemented: if res is NotImplemented:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论