提交 d93f815a authored 作者: lamblin's avatar lamblin

Merge pull request #670 from nouiz/fix_new_op

Fix new op
...@@ -5,6 +5,7 @@ from copy import copy ...@@ -5,6 +5,7 @@ from copy import copy
import graph import graph
import utils import utils
import toolbox import toolbox
from python25 import all
from theano import config from theano import config
...@@ -144,6 +145,18 @@ class Env(utils.object2): ...@@ -144,6 +145,18 @@ class Env(utils.object2):
# sets up node so it belongs to this env # sets up node so it belongs to this env
if hasattr(node, 'env') and node.env is not self: if hasattr(node, 'env') and node.env is not self:
raise Exception("%s is already owned by another env" % node) raise Exception("%s is already owned by another env" % node)
if (hasattr(node.op, 'view_map') and
not all([isinstance(view, (list, tuple))
for view in node.op.view_map.values()])):
raise Exception("Op '%s' have a bad view map '%s',"
" the values must be tuples or lists." % (
str(node.op), str(node.op.view_map)))
if (hasattr(node.op, 'destroy_map') and
not all([isinstance(destroy, (list, tuple))
for destroy in node.op.destroy_map.values()])):
raise Exception("Op '%s' have a bad destroy map '%s',"
" the values must be tuples or lists." % (
str(node.op), str(node.op.destroy_map)))
node.env = self node.env = self
node.deps = {} node.deps = {}
#self.execute_callbacks('on_setup_node', node) #self.execute_callbacks('on_setup_node', node)
......
...@@ -21,6 +21,10 @@ class DiffOp(theano.Op): ...@@ -21,6 +21,10 @@ class DiffOp(theano.Op):
def __init__(self, n=1, axis=-1): def __init__(self, n=1, axis=-1):
self.n = n self.n = n
self.axis = axis self.axis = axis
# numpy return a view in that case.
# TODO, make an optimization that remove this op in this case.
if n == 0:
self.view_map = {0: [0]}
def __eq__(self, other): def __eq__(self, other):
return (type(self) == type(other) and return (type(self) == type(other) and
......
...@@ -91,7 +91,7 @@ class TestDiffOp(utt.InferShapeTester): ...@@ -91,7 +91,7 @@ class TestDiffOp(utt.InferShapeTester):
def test_grad(self): def test_grad(self):
x = T.vector('x') x = T.vector('x')
a = np.random.random(500) a = np.random.random(50)
gf = theano.function([x], T.grad(T.sum(diff(x)), x)) gf = theano.function([x], T.grad(T.sum(diff(x)), x))
utt.verify_grad(self.op, [a]) utt.verify_grad(self.op, [a])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论