# This config contains the default values for training a FastPitch model with an aligner on audio
# sampled at 22.05 kHz. To train the model on another dataset, adjust the config values to match that
# dataset. Most dataset-specific arguments are at the top of this file; see below.

# Experiment name; also referenced by exp_manager.name.
name: FastPitch

# `???` marks a mandatory value that has to be provided (e.g. on the command line).
# These keys are consumed by model.train_ds / model.validation_ds through interpolation.
train_dataset: ???
validation_datasets: ???
sup_data_path: ???
sup_data_types:
  - align_prior_matrix
  - pitch

# Pitch range taken from librosa.pyin: the values equal librosa.note_to_hz('C2') and
# librosa.note_to_hz('C7'), the bounds recommended in the librosa.pyin documentation.
pitch_fmin: 65.40639132514966
pitch_fmax: 2093.004522404789

# Frame-wise pitch statistics. They depend on pitch_fmin and pitch_fmax, so recompute them
# whenever the range changes by running `scripts/dataset_processing/tts/extract_sup_data.py`.
# The example values are for https://zenodo.org/record/5525342/files/thorsten-neutral_v03.tgz?download=1
pitch_mean: ???  # e.g. 132.524658203125
pitch_std: ???   # e.g. 37.389366149902

# Audio and spectrogram defaults for datasets with sample_rate=22050.
# Each key is mirrored under `model` via `${...}` interpolation, so change it here only.
sample_rate: 22050
n_mel_channels: 80

n_fft: 1024
n_window_size: 1024
n_window_stride: 256
window: hann

lowfreq: 0
highfreq: null

model:
  # Aligner-related switches.
  learn_alignment: true
  bin_loss_warmup_epochs: 100

  # Model-level sizes; symbols_embedding_dim is reused by the sub-modules below.
  symbols_embedding_dim: 384
  pitch_embedding_kernel_size: 3
  max_token_duration: 75
  n_speakers: 1

  # Values below are interpolated from the top-level keys of the same name;
  # edit the top-level keys rather than these lines.
  pitch_fmin: ${pitch_fmin}
  pitch_fmax: ${pitch_fmax}
  pitch_mean: ${pitch_mean}
  pitch_std: ${pitch_std}

  sample_rate: ${sample_rate}
  n_mel_channels: ${n_mel_channels}
  n_fft: ${n_fft}
  n_window_size: ${n_window_size}
  n_window_stride: ${n_window_stride}
  window: ${window}
  lowfreq: ${lowfreq}
  highfreq: ${highfreq}

  # Text normalizer class and its constructor arguments (language code "de", cased input).
  text_normalizer:
    _target_: nemo_text_processing.text_normalization.normalize.Normalizer
    input_case: cased
    lang: de

  # Per the key name, extra keyword arguments for normalizer calls — confirm against the model code.
  text_normalizer_call_kwargs:
    punct_pre_process: true
    punct_post_process: true
    verbose: false

  # Character-level tokenizer for German text.
  text_tokenizer:
    _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.GermanCharsTokenizer
    pad_with_space: true
    apostrophe: true
    punct: true

  # Training data: dataset constructor arguments plus data-loader settings.
  train_ds:
    dataset:
      _target_: nemo.collections.tts.data.dataset.TTSDataset

      # Manifest and supplementary data, taken from the top-level keys.
      manifest_filepath: ${train_dataset}
      sup_data_path: ${sup_data_path}
      sup_data_types: ${sup_data_types}

      # Spectrogram settings, kept identical to the model section.
      sample_rate: ${model.sample_rate}
      n_fft: ${model.n_fft}
      win_length: ${model.n_window_size}
      hop_length: ${model.n_window_stride}
      window: ${model.window}
      n_mels: ${model.n_mel_channels}
      lowfreq: ${model.lowfreq}
      highfreq: ${model.highfreq}

      # Duration limits and optional ignore list.
      min_duration: 0.1
      max_duration: 15  # change to null to include longer audios.
      ignore_file: null

      # Trimming; frame and hop lengths reuse the STFT window size and stride.
      trim: true
      trim_top_db: 50
      trim_frame_length: ${model.n_window_size}
      trim_hop_length: ${model.n_window_stride}

      # Pitch range and normalization statistics.
      pitch_fmin: ${model.pitch_fmin}
      pitch_fmax: ${model.pitch_fmax}
      pitch_norm: true
      pitch_mean: ${model.pitch_mean}
      pitch_std: ${model.pitch_std}

    dataloader_params:
      batch_size: 32
      num_workers: 12
      shuffle: true
      drop_last: false
      pin_memory: true

  # Validation data: same dataset arguments as train_ds, but reading the validation manifest
  # and without shuffling.
  validation_ds:
    dataset:
      _target_: nemo.collections.tts.data.dataset.TTSDataset

      # Manifest and supplementary data, taken from the top-level keys.
      manifest_filepath: ${validation_datasets}
      sup_data_path: ${sup_data_path}
      sup_data_types: ${sup_data_types}

      # Spectrogram settings, kept identical to the model section.
      sample_rate: ${model.sample_rate}
      n_fft: ${model.n_fft}
      win_length: ${model.n_window_size}
      hop_length: ${model.n_window_stride}
      window: ${model.window}
      n_mels: ${model.n_mel_channels}
      lowfreq: ${model.lowfreq}
      highfreq: ${model.highfreq}

      # Duration limits and optional ignore list.
      min_duration: 0.1
      max_duration: 15  # change to null to include longer audios.
      ignore_file: null

      # Trimming; frame and hop lengths reuse the STFT window size and stride.
      trim: true
      trim_top_db: 50
      trim_frame_length: ${model.n_window_size}
      trim_hop_length: ${model.n_window_stride}

      # Pitch range and normalization statistics.
      pitch_fmin: ${model.pitch_fmin}
      pitch_fmax: ${model.pitch_fmax}
      pitch_norm: true
      pitch_mean: ${model.pitch_mean}
      pitch_std: ${model.pitch_std}

    dataloader_params:
      batch_size: 32
      num_workers: 8
      shuffle: false
      drop_last: false
      pin_memory: true

  # Mel-spectrogram preprocessor. All spectrogram dimensions are interpolated from the model
  # section so that they stay consistent with the dataset settings.
  preprocessor:
    _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor
    features: ${model.n_mel_channels}
    lowfreq: ${model.lowfreq}
    highfreq: ${model.highfreq}
    n_fft: ${model.n_fft}
    # Window size and stride are set through the n_window_* keys; window_size and
    # window_stride are set to false here.
    n_window_size: ${model.n_window_size}
    window_size: false
    n_window_stride: ${model.n_window_stride}
    window_stride: false
    pad_to: 1
    pad_value: 0
    sample_rate: ${model.sample_rate}
    window: ${model.window}
    normalize: null
    preemph: null
    dither: 0.0
    frame_splicing: 1
    log: true
    log_zero_guard_type: add
    # Written with an explicit mantissa dot: YAML 1.1 loaders (e.g. plain PyYAML) read a
    # dot-less `1e-05` as a string rather than a float.
    log_zero_guard_value: 1.0e-05
    mag_power: 1.0

  # Encoder-side transformer. n_embed and padding_idx are added by the model.
  input_fft:
    _target_: nemo.collections.tts.modules.transformer.FFTransformerEncoder
    # Both widths follow model.symbols_embedding_dim.
    d_model: ${model.symbols_embedding_dim}
    d_embed: ${model.symbols_embedding_dim}
    n_layer: 6
    n_head: 1
    d_head: 64
    d_inner: 1536
    kernel_size: 3
    dropout: 0.1
    dropatt: 0.1
    dropemb: 0.0

  # Decoder-side transformer; same hyperparameters as input_fft, without d_embed.
  output_fft:
    _target_: nemo.collections.tts.modules.transformer.FFTransformerDecoder
    d_model: ${model.symbols_embedding_dim}
    n_layer: 6
    n_head: 1
    d_head: 64
    d_inner: 1536
    kernel_size: 3
    dropout: 0.1
    dropatt: 0.1
    dropemb: 0.0

  # Alignment encoder; its text channel count follows model.symbols_embedding_dim.
  alignment_module:
    _target_: nemo.collections.tts.modules.aligner.AlignmentEncoder
    n_text_channels: ${model.symbols_embedding_dim}

  # Duration predictor; uses the same class and settings as pitch_predictor.
  duration_predictor:
    _target_: nemo.collections.tts.modules.fastpitch.TemporalPredictor
    input_size: ${model.symbols_embedding_dim}
    n_layers: 2
    filter_size: 256
    kernel_size: 3
    dropout: 0.1

  # Pitch predictor; uses the same class and settings as duration_predictor.
  pitch_predictor:
    _target_: nemo.collections.tts.modules.fastpitch.TemporalPredictor
    input_size: ${model.symbols_embedding_dim}
    n_layers: 2
    filter_size: 256
    kernel_size: 3
    dropout: 0.1

  # Optimizer and learning-rate schedule.
  optim:
    name: adamw
    # Exponent floats carry an explicit mantissa dot: YAML 1.1 loaders (e.g. plain PyYAML)
    # read dot-less forms such as `1e-3` as strings rather than floats.
    lr: 1.0e-3
    betas: [0.9, 0.999]
    weight_decay: 1.0e-6

    sched:
      name: NoamAnnealing
      warmup_steps: 1000
      last_epoch: -1
      d_model: 1  # Disable scaling based on model dim

# Trainer settings.
trainer:
  # Hardware and distribution.
  accelerator: gpu
  devices: -1  # -1 selects all GPUs
  num_nodes: 1
  strategy: ddp
  precision: 16
  benchmark: false

  # Optimization loop.
  max_epochs: 1500
  accumulate_grad_batches: 1
  gradient_clip_val: 1000.0
  check_val_every_n_epoch: 5
  log_every_n_steps: 100

  # Checkpointing and logging are both provided by exp_manager, so they are disabled here.
  enable_checkpointing: false
  logger: false

# Experiment manager: supplies the logger and checkpoint callback that are disabled in `trainer`.
exp_manager:
  name: ${name}  # top-level experiment name
  exp_dir: null

  create_tensorboard_logger: true
  create_checkpoint_callback: true
  checkpoint_callback_params:
    monitor: val_loss

  resume_if_exists: false
  resume_ignore_no_checkpoint: false