提交 788e6612 authored 作者: Frederic's avatar Frederic

Small fixes.

上级 304c9e03
...@@ -10,6 +10,9 @@ import theano ...@@ -10,6 +10,9 @@ import theano
from theano import gof from theano import gof
import numpy
def register_view_op_c_code(type, code, version=()): def register_view_op_c_code(type, code, version=()):
""" Tell ViewOp how to generate C code for a Theano Type """ Tell ViewOp how to generate C code for a Theano Type
...@@ -71,7 +74,7 @@ class ViewOp(gof.Op): ...@@ -71,7 +74,7 @@ class ViewOp(gof.Op):
if not v: if not v:
warnings.warn("Type %s has C code for ViewOp, but it has " warnings.warn("Type %s has C code for ViewOp, but it has "
"no version. You should add a 'version' keyword arg " "no version. You should add a 'version' keyword arg "
"when calling register_deep_copy_op_c_code." % t, "when calling register_view_op_c_code." % t,
stacklevel=2) stacklevel=2)
return () return ()
version.append((str(t), v)) version.append((str(t), v))
...@@ -161,7 +164,7 @@ class DeepCopyOp(gof.Op): ...@@ -161,7 +164,7 @@ class DeepCopyOp(gof.Op):
if not v: if not v:
warnings.warn("Type %s has C code for DeepCopyOp, but it has " warnings.warn("Type %s has C code for DeepCopyOp, but it has "
"no version. You should add a 'version' keyword arg " "no version. You should add a 'version' keyword arg "
"when calling register_OutputGuard_c_code." % t, "when calling register_deep_copy_op_c_code." % t,
stacklevel=2) stacklevel=2)
return () return ()
version.append((str(t), v)) version.append((str(t), v))
...@@ -266,7 +269,19 @@ class Shape(gof.Op): ...@@ -266,7 +269,19 @@ class Shape(gof.Op):
return super(Shape, self).c_code(node, name, inames, onames, sub) return super(Shape, self).c_code(node, name, inames, onames, sub)
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) version = []
# If any of the c code is unversionned, we have to return ()
# Else, we will return a list of (type name, version) pairs.
for t, (c, v) in sorted(self.c_code_and_version.items(), key=lambda pair: str(pair[0])):
if not v:
warnings.warn("Type %s has C code for Shape, but it has "
"no version. You should add a 'version' keyword arg "
"when calling register_shape_c_code." % t,
stacklevel=2)
return ()
version.append((str(t), v))
return tuple(version)
shape = Shape() shape = Shape()
...@@ -324,7 +339,7 @@ class Shape_i(gof.Op): ...@@ -324,7 +339,7 @@ class Shape_i(gof.Op):
if not v: if not v:
warnings.warn("Type %s has C code for Shape_i, but it has " warnings.warn("Type %s has C code for Shape_i, but it has "
"no version. You should add a 'version' keyword arg " "no version. You should add a 'version' keyword arg "
"when calling register_OutputGuard_c_code." % t, "when calling register_shape_i_c_code." % t,
stacklevel=2) stacklevel=2)
return () return ()
version.append((str(t), v)) version.append((str(t), v))
...@@ -548,7 +563,7 @@ class Rebroadcast(gof.Op): ...@@ -548,7 +563,7 @@ class Rebroadcast(gof.Op):
def infer_shape(self, node, ishapes): def infer_shape(self, node, ishapes):
assert len(ishapes) == 1 assert len(ishapes) == 1
l = [] l = []
one = constant(1) one = theano.tensor.basic.constant(1)
for ax in xrange(len(ishapes[0])): for ax in xrange(len(ishapes[0])):
if self.axis.get(ax, False): if self.axis.get(ax, False):
l.append(one) l.append(one)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论