提交 788e6612 authored 作者: Frederic's avatar Frederic

Small fixes.

上级 304c9e03
......@@ -10,6 +10,9 @@ import theano
from theano import gof
import numpy
def register_view_op_c_code(type, code, version=()):
""" Tell ViewOp how to generate C code for a Theano Type
......@@ -71,7 +74,7 @@ class ViewOp(gof.Op):
if not v:
warnings.warn("Type %s has C code for ViewOp, but it has "
"no version. You should add a 'version' keyword arg "
"when calling register_deep_copy_op_c_code." % t,
"when calling register_view_op_c_code." % t,
stacklevel=2)
return ()
version.append((str(t), v))
......@@ -161,7 +164,7 @@ class DeepCopyOp(gof.Op):
if not v:
warnings.warn("Type %s has C code for DeepCopyOp, but it has "
"no version. You should add a 'version' keyword arg "
"when calling register_OutputGuard_c_code." % t,
"when calling register_deep_copy_op_c_code." % t,
stacklevel=2)
return ()
version.append((str(t), v))
......@@ -266,7 +269,19 @@ class Shape(gof.Op):
return super(Shape, self).c_code(node, name, inames, onames, sub)
def c_code_cache_version(self):
return (1,)
version = []
# If any of the c code is unversionned, we have to return ()
# Else, we will return a list of (type name, version) pairs.
for t, (c, v) in sorted(self.c_code_and_version.items(), key=lambda pair: str(pair[0])):
if not v:
warnings.warn("Type %s has C code for Shape, but it has "
"no version. You should add a 'version' keyword arg "
"when calling register_shape_c_code." % t,
stacklevel=2)
return ()
version.append((str(t), v))
return tuple(version)
shape = Shape()
......@@ -324,7 +339,7 @@ class Shape_i(gof.Op):
if not v:
warnings.warn("Type %s has C code for Shape_i, but it has "
"no version. You should add a 'version' keyword arg "
"when calling register_OutputGuard_c_code." % t,
"when calling register_shape_i_c_code." % t,
stacklevel=2)
return ()
version.append((str(t), v))
......@@ -548,7 +563,7 @@ class Rebroadcast(gof.Op):
def infer_shape(self, node, ishapes):
assert len(ishapes) == 1
l = []
one = constant(1)
one = theano.tensor.basic.constant(1)
for ax in xrange(len(ishapes[0])):
if self.axis.get(ax, False):
l.append(one)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论