提交 5111df45 authored 作者: gdesjardins's avatar gdesjardins

merge

......@@ -364,6 +364,12 @@ class Constant(Value):
if self.name is not None:
return self.name
return str(self.data) #+ "::" + str(self.type)
def clone(self):
"""
We clone this object, but we don't clone the data to lower memory requirement
We suppose that the data will never change.
"""
return self.__class__(self.type, self.data, self.name)
......
......@@ -3712,6 +3712,11 @@ pprint.assign(dot, printing.OperatorPrinter(printing.special['middle_dot'], -1,
#########################
class TensorDotGrad(Op):
def __init__(self, axes):
if isinstance(axes,list):
for i,a in enumerate(axes):
if isinstance(a,list):
axes[i]=tuple(a)
axes=tuple(axes)
self.axes = axes;
def __eq__(self, other):
......@@ -3759,6 +3764,11 @@ class TensorDot(Op):
"""
def __init__(self, axes):
if isinstance(axes,list):
for i,a in enumerate(axes):
if isinstance(a,list):
axes[i]=tuple(a)
axes=tuple(axes)
self.axes = axes;
def __eq__(self, other):
......
......@@ -2444,6 +2444,20 @@ class test_tensordot(unittest.TestCase):
f6(bval,aval)))
utt.verify_grad(TensorDot(axes), [bval,aval])
def test_list(self):
# test matrix-matrix
amat = dmatrix()
bmat = dmatrix()
axes = [[1,],[0,]]
c = tensordot(axes)(amat, bmat)
f3 = inplace_func([amat,bmat],c)
aval = numpy.random.rand(4,7);
bval = numpy.random.rand(7,9);
self.failUnless(numpy.all(numpy.tensordot(aval,bval,axes) == \
f3(aval,bval)))
utt.verify_grad(TensorDot(axes), [aval,bval])
def test_smallest_stack():
sx, sy = dscalar(), dscalar()
......
......@@ -1015,15 +1015,15 @@ def test_log1p():
# the first three ops are Shape_i, Shape_i, and Dimshuffle
theano.printing.debugprint(f)
assert [node.op for node in f.maker.env.toposort()][3:] \
== [inplace.log1p_inplace, alloc]
== [T.log1p, alloc]
f = function([x,y], T.log(0+(x) + fill(y,1.0)), mode=m)
theano.printing.debugprint(f)
assert [node.op for node in f.maker.env.toposort()][3:] \
== [inplace.log1p_inplace, alloc]
== [T.log1p, alloc]
f = function([x,y], T.log(2+(x) - fill(y,1.0)), mode=m)
theano.printing.debugprint(f)
assert [node.op for node in f.maker.env.toposort()][3:] \
== [inplace.log1p_inplace, alloc]
== [T.log1p, alloc]
f([1e-7, 10], [[0, 0], [0, 0]]) #debugmode will verify values
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论