ort_session = onnxruntime.InferenceSession("vits.onnx",  providers=['CUDAExecutionProvider'])
print('onnx InferenceSession ok.')

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

# compute ONNX Runtime output prediction
ort_inputs = {ort_session.get_inputs()[0].name:to_numpy(x_tst), ort_session.get_inputs()[1].name:to_numpy(x_tst_lengths)}
print(datetime.datetime.now())
ort_outs = ort_session.run(None, ort_inputs)
print(datetime.datetime.now())
print(type(ort_outs))

ort_outs = np.array(ort_outs)
print(type(ort_outs))
print(ort_outs)
ort_audio = ort_outs[0,0,0]
print(ort_audio)

print(datetime.datetime.now())

save_wav(ort_audio, "./onnx_baker_long.wav", 16000)