import argparse
import json
import logging
import math
import multiprocessing
import os
import sys
from multiprocessing import cpu_count

import ijson
import soundfile
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
from tqdm import tqdm

# Make the project root importable so `utils.binary` can be found when this
# script is run from inside its own directory
sys.path.insert(0, sys.path[0] + "/../")
from utils.binary import DatasetWriter

# Keep modelscope's logger quiet so the progress bars stay readable
logger = get_logger(log_level=logging.CRITICAL)
logger.setLevel(logging.CRITICAL)


parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('--wenetspeech_json', type=str, default='/media/WenetSpeech数据集/WenetSpeech.json',
                    help="path to the WenetSpeech annotation JSON file")
# argparse's type=bool would treat any non-empty string (even 'False') as True,
# so parse the value explicitly
parser.add_argument('--add_pun', type=lambda s: str(s).lower() in ('true', '1', 'yes'), default=True,
                    help="whether to add punctuation")
parser.add_argument('--annotation_dir', type=str, default='../dataset/', help="directory to store the data lists")
args = parser.parse_args()

if not os.path.exists(args.annotation_dir):
    os.makedirs(args.annotation_dir)

# Paths of the training and test data lists
train_list_path = os.path.join(args.annotation_dir, 'train_wenet.json')
test_net_path = os.path.join(args.annotation_dir, 'test_net.json')
test_meeting_path = os.path.join(args.annotation_dir, 'test_meeting.json')


# Read the annotation info
def get_data(wenetspeech_json):
    data_list = []
    input_dir = os.path.dirname(wenetspeech_json)
    # Stream the annotation file with ijson: it is far too large to load in one
    # go, which also means the total count (and a precise progress bar) is unknown
    with open(wenetspeech_json, 'r', encoding='utf-8') as f:
        objects = ijson.items(f, 'audios.item')
        print("Start reading data...")
        for long_audio in objects:
            aid = long_audio.get('aid', 'unknown')
            try:
                long_audio_path = os.path.realpath(os.path.join(input_dir, long_audio['path']))
                segments_lists = long_audio['segments']
                assert os.path.exists(long_audio_path)
            except AssertionError:
                print(f"Warning: {long_audio_path} does not exist or was already processed and removed, skipping")
                continue
            except Exception:
                print(f"Warning: failed to read audio {aid}, skipping")
                continue
            else:
                data_list.append([long_audio_path.replace('\\', '/'), segments_lists])
        print("Finished reading data")
    return data_list


def main():
    f_train = open(train_list_path, 'w', encoding='utf-8')
    f_test_net = open(test_net_path, 'w', encoding='utf-8')
    f_test_meeting = open(test_meeting_path, 'w', encoding='utf-8')

    all_data = get_data(args.wenetspeech_json)
    print(f'Total number of long audios: {len(all_data)}')
    for data in tqdm(all_data):
        long_audio_path, segments_lists = data
        for segment_file in segments_lists:
            start_time = float(segment_file['begin_time'])
            end_time = float(segment_file['end_time'])
            text = segment_file['text']
            confidence = segment_file['confidence']
            # Discard segments whose transcription confidence is below 0.95
            if confidence < 0.95:
                continue
            line = dict(audio={"path": long_audio_path,
                               "start_time": round(start_time, 3),
                               "end_time": round(end_time, 3)},
                        sentence=text,
                        duration=round(end_time - start_time, 3))
            # In the WenetSpeech layout the subset name (train/test_net/test_meeting)
            # is the fourth-from-last path component
            data_type = long_audio_path.split('/')[-4]
            if data_type == 'test_net':
                f_test_net.write(json.dumps(line, ensure_ascii=False) + '\n')
            elif data_type == 'test_meeting':
                f_test_meeting.write(json.dumps(line, ensure_ascii=False) + '\n')
            elif data_type == 'train':
                f_train.write(json.dumps(line, ensure_ascii=False) + '\n')
    f_train.close()
    f_test_meeting.close()
    f_test_net.close()
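
# Each line written by main() is a standalone JSON object, e.g. (illustrative values):
# {"audio": {"path": ".../audio/train/youtube/B00000/Y00000.opus", "start_time": 10.0,
#            "end_time": 15.5}, "sentence": "...", "duration": 5.5}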


# Merge multiple segments into longer utterances with timestamps, which also speeds up training
def merge_list():
    for file_path in [train_list_path, test_net_path, test_meeting_path]:
        with open(file_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        with open(file_path, 'w', encoding='utf-8') as f:
            sentences = []
            duration = 0
            start_time = 0
            text = ''
            for i in tqdm(range(len(lines))):
                data = json.loads(lines[i])
                sentence = data["sentence"]
                # A new merged utterance starts here
                if duration == 0:
                    start_time = data['audio']["start_time"]
                duration = data['audio']["end_time"] - start_time
                # Keep per-sentence timestamps relative to the merged utterance
                sentences.append({"start": round(data['audio']["start_time"] - start_time, 2),
                                  "end": round(data['audio']['end_time'] - start_time, 2),
                                  "text": sentence})
                text += sentence
                name = data['audio']['path']
                # Flush on the last line, when the next line belongs to a different
                # long audio, or when adding it would exceed 30 seconds
                flush = i == len(lines) - 1
                if not flush:
                    next_data = json.loads(lines[i + 1])
                    flush = (next_data['audio']['path'] != name
                             or next_data['audio']["end_time"] - start_time >= 30)
                if flush:
                    data1 = dict()
                    data1['audio'] = {"path": name,
                                      "start_time": start_time,
                                      "end_time": data['audio']['end_time']}
                    data1['duration'] = round(data['audio']['end_time'] - start_time, 2)
                    data1['sentence'] = text
                    data1['sentences'] = sentences
                    f.write(f'{json.dumps(data1, ensure_ascii=False)}\n')
                    sentences = []
                    duration = 0
                    start_time = 0
                    text = ''
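
# After merging, each line bundles consecutive segments of one long audio, e.g.
# (illustrative values):
# {"audio": {"path": "....opus", "start_time": 10.0, "end_time": 38.5}, "duration": 28.5,
#  "sentence": "...", "sentences": [{"start": 0.0, "end": 3.2, "text": "..."}, ...]}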


# Zero out the given unannotated spans of each audio file and convert it to FLAC
def process_audio(data, i):
    for path, sentences in tqdm(data, desc=f"worker {i}"):
        if not os.path.exists(path):
            continue
        save_path = os.path.splitext(path)[0] + '.flac'
        # Skip files that were already converted by a previous run
        if os.path.exists(save_path):
            continue
        sample, sr = soundfile.read(path)
        for start, end in sentences:
            # Keep a 0.1 s margin on each side of the gap before silencing it
            start = max(int((start + 0.1) * sr), 0)
            end = min(int((end - 0.1) * sr), len(sample))
            sample[start:end] = 0
        soundfile.write(save_path, sample, sr)
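
# Note: decoding .opus with soundfile requires a libsndfile build that includes
# Ogg/Opus support; without it, soundfile.read() will fail on these files.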


# Silence the positions that have no annotation
def set_silence():
    for file_path in [train_list_path, test_net_path]:
        with open(file_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        all_data = {}
        for line in tqdm(lines, desc='Reading data list'):
            data = json.loads(line)
            path = data['audio']['path']
            # Only the original .opus files still need processing
            if os.path.splitext(path)[-1] != '.opus':
                continue
            start_a = data['audio']['start_time']
            sentences = data['sentences']
            last_end = start_a
            # Register the file and collect the gaps longer than one second
            # between consecutive annotated sentences
            intervals = all_data.setdefault(path, [])
            for sentence in sentences:
                start = round(start_a + sentence['start'], 3)
                if start - last_end > 1:
                    intervals.append([last_end, start])
                last_end = round(start_a + sentence['end'], 3)
        # Process the audio with multiple worker processes
        all_data = list(all_data.items())
        num_worker = cpu_count()
        length = math.ceil(len(all_data) / num_worker)
        data = [all_data[i * length:(i + 1) * length] for i in range(num_worker)]
        my_process = []
        for i in range(num_worker):
            process = multiprocessing.Process(target=process_audio, args=(data[i], i))
            my_process.append(process)
        for process in my_process:
            process.start()
        for process in my_process:
            process.join()
        # Update the paths in the list since the audio was converted to FLAC
        with open(file_path, 'w', encoding='utf-8') as f:
            for line in tqdm(lines, desc='Updating path suffixes'):
                data = json.loads(line)
                path = data['audio']['path']
                path = path.replace('.opus', '.flac')
                if not os.path.exists(path):
                    print(f'{path} does not exist', file=sys.stderr)
                    continue
                data['audio']['path'] = path
                f.write(json.dumps(data, ensure_ascii=False) + '\n')


# Restore punctuation marks
def process_pun(data, i):
    # Build the modelscope punctuation pipeline once per worker process
    inference_pipeline = pipeline(task=Tasks.punctuation,
                                  model='damo/punc_ct-transformer_cn-en-common-vocab471067-large',
                                  model_revision="v1.0.0")
    f = open(f'temp{i}.txt', 'w', encoding='utf-8')
    for line in tqdm(data, desc=f"worker {i}"):
        item = json.loads(line)
        sentence = item['sentence']
        # Strip any existing punctuation before re-punctuating the full sentence
        sentence = sentence.replace(',', '').replace('。', '').replace('?', '').replace('!', '').replace('、', '')
        item['sentence'] = inference_pipeline(text_in=sentence)['text']

        # The cache carries context across the per-sentence calls below
        param_dict = {"cache": []}
        sentences = item['sentences']
        for j in range(len(sentences)):
            text = sentences[j]['text']
            text = text.replace(',', '').replace('。', '').replace('?', '').replace('!', '').replace('、', '')
            sentences[j]['text'] = inference_pipeline(text_in=text, param_dict=param_dict)['text']
        f.write(json.dumps(item, ensure_ascii=False) + '\n')
    f.close()


# Add punctuation using multiple processes
def add_pun():
    for file_path in [train_list_path, test_net_path, test_meeting_path]:
        with open(file_path, 'r', encoding='utf-8') as f:
            all_data = f.readlines()
        # Number of punctuation workers; tune this to the available GPU memory
        num_worker = 4
        length = math.ceil(len(all_data) / num_worker)
        data = [all_data[i * length:(i + 1) * length] for i in range(num_worker)]
        my_process = []
        for i in range(num_worker):
            process = multiprocessing.Process(target=process_pun, args=(data[i], i))
            my_process.append(process)
        for process in my_process:
            process.start()
        for process in my_process:
            process.join()
        # Merge the per-worker shards back into a single list file
        with open(file_path, 'w', encoding='utf-8') as fw:
            for i in range(num_worker):
                with open(f'temp{i}.txt', 'r', encoding='utf-8') as fr:
                    fw.writelines(fr.readlines())
                # The temporary shard is no longer needed once merged
                os.remove(f'temp{i}.txt')


# Convert the list into a binary file to reduce memory usage
def create_binary():
    print('Converting the data list to a binary file...')
    dataset_writer = DatasetWriter(f"{args.annotation_dir}/train")
    with open(train_list_path, 'r', encoding='utf-8') as f:
        lines = f.readlines()
    for line in tqdm(lines):
        dataset_writer.add_data(line.rstrip('\n'))
    dataset_writer.close()


if __name__ == '__main__':
    main()
    # Merge multiple segments, adding timestamps, which also speeds up training
    merge_list()
    # Silence the positions that have no annotation
    set_silence()
    # Restore punctuation marks
    if args.add_pun:
        add_pun()
    # Convert the list into a binary file to reduce memory usage
    create_binary()
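
# Example invocation (the script filename and paths here are illustrative):
#   python create_data.py --wenetspeech_json=/path/to/WenetSpeech.json \
#       --annotation_dir=../dataset/ --add_pun=True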