提交 a43d461f authored 作者: Amjad Almahairi's avatar Amjad Almahairi

add support for old interface

上级 e1f691c3
...@@ -35,7 +35,7 @@ class MultinomialFromUniform(Op): ...@@ -35,7 +35,7 @@ class MultinomialFromUniform(Op):
except AttributeError: except AttributeError:
self.odtype = 'auto' self.odtype = 'auto'
def make_node(self, pvals, unis, n): def make_node(self, pvals, unis, n=1):
pvals = T.as_tensor_variable(pvals) pvals = T.as_tensor_variable(pvals)
unis = T.as_tensor_variable(unis) unis = T.as_tensor_variable(unis)
if pvals.ndim != 2: if pvals.ndim != 2:
...@@ -151,7 +151,12 @@ class MultinomialFromUniform(Op): ...@@ -151,7 +151,12 @@ class MultinomialFromUniform(Op):
""" % locals() """ % locals()
def perform(self, node, ins, outs): def perform(self, node, ins, outs):
(pvals, unis, n_samples) = ins # support old pickled graphs
if len(ins) == 2:
(pvals, unis) = ins
n_samples = 1
else:
(pvals, unis, n_samples) = ins
(z,) = outs (z,) = outs
if unis.shape[0] != pvals.shape[0] * n_samples: if unis.shape[0] != pvals.shape[0] * n_samples:
...@@ -264,7 +269,13 @@ class GpuMultinomialFromUniform(MultinomialFromUniform, GpuOp): ...@@ -264,7 +269,13 @@ class GpuMultinomialFromUniform(MultinomialFromUniform, GpuOp):
""" % locals() """ % locals()
def c_code(self, node, name, ins, outs, sub): def c_code(self, node, name, ins, outs, sub):
(pvals, unis) = ins # support old pickled graphs
if len(ins) == 2:
(pvals, unis) = ins
n_samples = 1
else:
(pvals, unis, n_samples) = ins
(z,) = outs (z,) = outs
fail = sub['fail'] fail = sub['fail']
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论