提交 1dabbb54 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

figure out dtype automatically in tensor_random

上级 47fedf92
...@@ -128,8 +128,12 @@ class NumpyGenerator(gof.op.Op): ...@@ -128,8 +128,12 @@ class NumpyGenerator(gof.op.Op):
shape = tensor.convert_to_int64(_shape) shape = tensor.convert_to_int64(_shape)
if shape.type.ndim != 1: if shape.type.ndim != 1:
raise TypeError('shape argument was not converted to 1-d tensor', _shape) raise TypeError('shape argument was not converted to 1-d tensor', _shape)
# we generate one random number with the distribution to determine what dtype to expect
output_dtype = str(self.fn(numpy.random.RandomState(18), size=(1,)).dtype)
inputs = [gof.Value(gof.type.generic, numpy.random.RandomState(self.seed)), shape] inputs = [gof.Value(gof.type.generic, numpy.random.RandomState(self.seed)), shape]
outputs = [tensor.Tensor(dtype='float64', broadcastable = [False]*self.ndim).make_result()] outputs = [tensor.Tensor(dtype=output_dtype, broadcastable = [False]*self.ndim).make_result()]
return gof.Apply(op = self, inputs = inputs, outputs = outputs) return gof.Apply(op = self, inputs = inputs, outputs = outputs)
def grad(self, inputs, grad_outputs): def grad(self, inputs, grad_outputs):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论