#!/usr/bin/env python

# Authors: Ronald L. Sprouse (ronald@berkeley.edu)
# 
# Copyright (c) 2017-2018, The Regents of the University of California
# All rights reserved.
# 
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
# 
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
# 
# * Neither the name of the University of California nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
# 
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import os, sys
import re
import getopt
import subprocess
import audiolabel

# Command-line synopsis listing every option; printed by both usage() and
# help().
standard_usage_str = '''
multi_align [--tiers tierstring] [--dict-codec codec] [--out-codec codec]
[--in-codec codec] [--no-check-dict] [--input-type type] [--input textgrid]
[--output textgrid] [--align-cmd aligner_path] [--text2tg textgrid]
wavfile
'''
# Synopsis of the help invocation; printed after standard_usage_str by both
# usage() and help().
help_usage_str = '''multi_align --help|-h'''

def usage():
    '''Print the short usage summary (standard synopsis followed by the
help synopsis) to stdout.'''
    # Each section is preceded by a blank line; the help synopsis is also
    # followed by one.
    for section in (standard_usage_str, help_usage_str + '\n'):
        print('\n' + section)

def help():
    '''Print the full help text to stdout: the usage synopsis followed by a
description of the program, its arguments, and its output file.'''
    print(standard_usage_str)
    print('\n' + help_usage_str + '\n')
    print('''
multi_align: Iteratively call aligner for every non-empty label in one
or more textgrid tiers and compile the result into a single textgrid.

Before running the aligner, the input textgrid is evaluated against the
aligner's main dictionary and the dict.local file, if it exists. If all
of the words to be aligned are found in the combined dictionary, then
the alignment will run as normal. If any input words are not in the
combined dictionary, then alignment will not occur and the missing
words will be written to STDERR, one word per line, and multi_align
will exit with an error code.
'''
    )
    print('''
Required argument:

    wavfile
    The name of the .wav file to be aligned.

Optional parameters:

    --tiers tierstring
    A comma-separated list of tiers containing labels to align. Do not
    include spaces. Tiers are identified by name.

    If the --tiers parameter is not used, then all tiers will be aligned.

    To align specific tiers to stereo or multi-channel audio, specify channel
    assignments by putting a colon and then the channel number after the tier
    name.

        Examples:

        # Align labels in tier named 'speaker'
        --tiers speaker

        # Align labels in tiers named 'speaker1' and 'speaker2'
        --tiers speaker1,speaker2

        # Align labels in tier 'speaker1' to channel 1 and in tier 'speaker2'
          to channel 2
        --tiers speaker1,speaker2:2

    If you use the '--input-type raw|text' parameter, a single tier named 'speaker'
    is available for mapping to an input audio channel.

    --dict-codec codec
    Codec from the Python codec module used for decoding of the aligner's
    dictionary. Use this parameter if the encoding used by your dictionary
    file (normally dict.local) does not match the encoding used by your
    input textgrid, and multi_align will attempt to recode into the
    dictionary encoding. Default is to match the encoding of the input
    textgrid.

    --out-codec codec
    Codec from the Python codec module used for encoding of output textgrid.
    Default is to match the encoding of the input textgrid.

    --in-codec codec
    Codec from the Python codec module used for decoding of input textgrid.
    Normally textgrid encoding can be detected automatically. Use this parameter
    if automatic detection fails.

    --no-check-dict
    When this parameter is included the aligner will always run without
    checking that all of the words in the input textgrid are found in the
    dictionary.

    --input-type textgrid|text|raw (default: textgrid)
    If included, this parameter indicates the kind of input transcript
    file. The default is 'textgrid' for a Praat textgrid. The other options
    'text' and 'raw' select a simple text input with no time indicators; the
    utterance defined by the text will be aligned to the entire audio file.
    If 'text' is selected, then the value of the --input parameter names a
    text file containing the utterance. If 'raw' is selected, then the value
    of --input is the text to be aligned.

    If the --tiers argument is used with --input-type text|raw, then
    multi_align will process the simple text input as if the input had been
    provided with a single annotation tier with the name provided by --tiers.
    This behavior can be useful to map the text input to a single channel of
    a stereo audio file, and it also causes the tier name to propagate to
    output files.

    See also the --input and --text2tg parameters for interactions with
    this parameter.

    --input transcript_file|str
    The name of the input textgrid with labels to align. If this parameter is
    not provided, then the script will look for a file with the same name as
    wavfile and with the extension '.TextGrid', for example, 'myfile.TextGrid'
    for 'myfile.wav'. If '--input-type text' was also used, then the
    default input filename has the extension '.txt' instead of '.TextGrid'.

    If '--input-type raw' was also used, then the value of --input is the
    utterance transcript to align against the input audio file.

    --output textgrid
    By default the input textgrid name forms the basis of the output
    textgrid's name, which will have '.multi_align' inserted immediately
    before the extension. For example, the output file for an input
    textgrid named 'myfile.TextGrid' is 'myfile.multi_align.TextGrid'.
    Use --output to specify an alternative name.

    --align-cmd aligner_path
    The name of the aligner executable to be invoked as the first argument
    when creating the aligner process via Python's subprocess module.
    If the path is not provided as part of the name, the executable must be
    found in the environment $PATH. Default is 'pyalign'.

    --text2tg textgrid
    If included, this parameter provides the name of a textgrid that is
    created from a simple text input and suitable for use as an input
    textgrid on a future invocation of multi_align. This textgrid contains
    a single tier (named 'speaker') that contains a single label. The label
    contains the input transcript and extends over the entire length of
    the audio file.

    This parameter requires '--input-type text|raw' and throws an error if
    --input-type is not 'text' or 'raw'.

Output file:

    The output file is a textgrid with two tiers (word and phone) for each input
    tier.

    If only one tier is aligned, then the output tiers are named 'word' and
    'phone', as they normally are named by pyalign.

    If multiple tiers are aligned, the output tiers are named by concatenating
    the suffixes '_word' and '_phone' to the input tier names, e.g.
    'speaker1_word' and 'speaker1_phone' for input tier 'speaker1'.

    See also --text2tg for a secondary output file that can be created.
'''
    )

def duration(audiofile):
    '''Return the duration of audiofile as a float, as reported by the
external command `sox --info -D`.'''
    sox_cmd = ['sox', '--info', '-D', audiofile]
    reported = subprocess.check_output(sox_cmd)
    # sox output carries surrounding whitespace; strip it before converting.
    return float(reported.strip())

def _read_dict_words(dictfile, dictcodec):
    '''Return the set of headwords (the first whitespace-delimited field of
every non-blank line) found in dictfile, decoded with dictcodec. Return an
empty set if the file cannot be opened or read (IOError).'''
    words = set()
    try:
        with open(dictfile, 'r', encoding=dictcodec) as fh:
            for line in fh:
                fields = line.split()
                if fields:
                    words.add(fields[0])
    except IOError:
        # Best effort: an absent dictionary file (e.g. no dict.local)
        # simply contributes no words.
        return set()
    return words

def find_missing_dictwords(tiers, pattern, main_dict, local_dict, dictcodec):
    '''Search tiers for labels matching pattern and determine if there are any
words in those labels that are not in the dictionary.

tiers is an iterable of objects providing search(pattern), which yields
labels that have a text attribute. main_dict and local_dict are dictionary
filenames that are read with the dictcodec encoding; a dictionary file that
cannot be opened is treated as empty.

Return a list of the missing words (uppercased, in order of first
appearance, without duplicates), or None if no words are missing.'''
    # Sets give constant-time membership tests; the dictionaries can be large.
    dictwords = (
        _read_dict_words(main_dict, dictcodec) |
        _read_dict_words(local_dict, dictcodec)
    )

    # Prep text to be aligned as in prep_mlf in align.py.
    # this pattern matches hyphenated words, such as TWENTY-TWO; however,
    # it doesn't work with longer things like SOMETHING-OR-OTHER
    hyphenPat = re.compile(r'([A-Z]+)-([A-Z]+)')
    punclist = [',', '.', ':', ';', '!', '?', '"', '%', '(', ')', '--', '---']
    # Noise markers are rewritten to their dictionary forms, in this order.
    noise_markers = [
        ('{breath}', '{BR}'), ('<noise>', '{NS}'),
        ('{laugh}', '{LG}'), ('{laughter}', '{LG}'),
        ('{cough}', '{CG}'), ('{lipsmack}', '{LS}'),
    ]
    missing = []        # Missing words in order of first appearance.
    reported = set()    # Same words, for fast duplicate checks.
    for tier in tiers:
        for lab in tier.search(pattern):
            txt = lab.text.replace('\n', '')
            for marker, replacement in noise_markers:
                txt = txt.replace(marker, replacement)
            for pun in punclist:
                txt = txt.replace(pun, '')
            txt = txt.upper()
            # break up any hyphenated words into two separate words
            txt = hyphenPat.sub(r'\1 \2', txt)

            for wrd in txt.split():
                if wrd not in dictwords and wrd not in reported:
                    reported.add(wrd)
                    missing.append(wrd)
    return missing or None

def append_labels(labels, tier):
    '''Append a series of labels to a tier and insert empty label to fill gap.
Note that the tier is passed by reference and altered in place.

labels is a sequence of label objects (with t1 and t2 attributes) in time
order. An empty sequence leaves the tier untouched. If the first label starts
after the end of the tier's last label (or after tier.start when the tier has
no labels yet), an empty-text label is added to cover the gap.

Raise ValueError if the first label starts before the current end of the
tier, since adding it would create overlapping labels.'''
    if not labels:           # Nothing to append (e.g. aligner gave no labels).
        return
    try:
        end = tier[-1].t2    # Get end time of last label in tier.
    except IndexError:       # Tier does not contain any labels yet.
        end = tier.start
    first_t1 = labels[0].t1
    # Explicit check instead of assert so that it still runs under python -O.
    if end > first_t1:
        raise ValueError(
            'label starting at {:} overlaps tier content ending at {:}'.format(
                first_t1, end
            )
        )
    if end < first_t1:       # Fill gap if needed.
        tier.add(audiolabel.Label(text='', t1=end, t2=first_t1))
    for lab in labels:
        tier.add(lab)

def _pad_tier_end(tier):
    '''Add an empty label covering the span from the end of the tier's last
label (or from tier.start if the tier has no labels) to tier.end, if that
span is not empty. The tier is altered in place.'''
    try:
        last_end = tier[-1].t2
    except IndexError:       # No labels were added to the tier at all.
        last_end = tier.start
    if last_end < tier.end:
        tier.add(audiolabel.Label(text='', t1=last_end, t2=tier.end))

def multi_align_tier(tier, wav, channel, pattern=None, dictcodec=None,
        align_cmd='pyalign'):
    '''Align all the non-empty labels in a tier on wav. Return the output
phone and word tiers.

For every label returned by tier.search(pattern), the label text is written
to a temporary transcript file ('temp_transcript.txt' in the current
directory, encoded with dictcodec if given) and align_cmd is run with the
label's t1 and t2 and the channel passed via its -s, -e and -c options. The
'word' and 'phone' tiers of the textgrid written by the aligner
('temp_textgrid.TextGrid') are appended to the output tiers.

If the aligner exits with a non-zero status, a message is written to stderr
and a label whose text starts with '**ERROR**' is added to both output tiers
in place of the alignment for that label.

Return the tuple (phone_tier, word_tier). Both tiers span tier.start to
tier.end, with empty labels filling any unaligned regions.'''
    temp_tg = 'temp_textgrid.TextGrid'
    temp_txt = 'temp_transcript.txt'
    pt = audiolabel.IntervalTier(
        name = 'phone',
        start=tier.start,
        end=tier.end
    )
    wt = audiolabel.IntervalTier(
        name = 'word',
        start=tier.start,
        end=tier.end
    )

    # Align each non-empty label and add result to word/phone tiers.
    for lab in tier.search(pattern):
        if dictcodec is None:
            with open(temp_txt, 'w') as f:
                f.write(lab.text)
        else:
            with open(temp_txt, 'wb') as f:
                f.write(lab.text.encode(dictcodec))
        args = [
            align_cmd,
            '-s', str(lab.t1),
            '-e', str(lab.t2),
            '-c', str(channel),
            wav,
            temp_txt,
            temp_tg
        ]
        try:
            subprocess.check_call(args)
        except subprocess.CalledProcessError as e:
            sys.stderr.write(
                'Error aligning word at t1 {:0.4f} from tier {:}\n'.format(
                     lab.t1, tier.name
                )
            )
            # Each output tier gets its own error label object so the tiers
            # never share a label instance.
            for outtier in (wt, pt):
                errlab = audiolabel.Label(
                    text='**ERROR** ' + str(e), t1=lab.t1, t2=lab.t2
                )
                append_labels([errlab], outtier)
        else:
            tempm = audiolabel.LabelManager(
                from_file=temp_tg, from_type='praat'
            )
            append_labels([l for l in tempm.tier('word')], wt)
            append_labels([l for l in tempm.tier('phone')], pt)
    # Add empty labels at end of tier, if needed. This also covers the case
    # in which no label matched pattern and the output tiers are still empty.
    _pad_tier_end(pt)
    _pad_tier_end(wt)
    return (pt, wt)

def _parse_tier_channels(tierstring):
    '''Parse the value of --tiers into a dict mapping tier name to channel.

tierstring is a comma-separated list of items of the form 'name' or
'name:channel', where channel is a string of digits. Items without an
explicit channel are assigned channel '1'. Empty items are skipped. Channels
are kept as strings. The dict preserves the order of the items.'''
    item_pat = re.compile(r'(.+):(\d+)$')
    channels = {}
    for item in tierstring.split(','):
        if item == '':
            continue
        m = item_pat.search(item)
        if m is None:
            # Add channel 1 default where not explicitly provided by user.
            channels[item] = '1'
        else:
            channels[m.group(1)] = m.group(2)
    return channels

def main(argv=None):
    '''Run multi_align with the command line arguments in argv (default:
sys.argv[1:]).

Parse the options, load the input transcript (a Praat textgrid, a text file,
or a raw transcript string, depending on --input-type), optionally check the
transcript words against the aligner dictionaries, align every selected tier
with multi_align_tier(), and write the resulting textgrid to the output file.

Exit with status 2 on a usage error, 3 if words are missing from the
dictionary, and 0 after printing help.'''
    if argv is None:
        argv = sys.argv[1:]
    try:
        opts, args = getopt.getopt(
            argv,
            "h",
            ["tiers=", "dict-codec=", "out-codec=", "in-codec=", "input=",
             "output=", "no-check-dict", "input-type=", "text2tg=", "align-cmd=",
             "help"]
        )
    except getopt.GetoptError as err:
        sys.stderr.write(str(err))
        usage()
        sys.exit(2)
    textgrid = None      # Value of --input: a filename, or raw transcript.
    outfile = None
    pattern = r'\S'
    main_dict = '/opt/p2fa/model/dict'
    local_dict = 'dict.local'
    check_dict = True
    tiersdict = {}       # Tier name -> channel; empty means all tiers.
    dictcodec = None
    outcodec = None
    incodec = None
    align_cmd = 'pyalign'
    input_type = 'textgrid'
    text2tg_file = None
    # getopt always reports the full long option name, so compare with ==.
    for o, a in opts:
        if o == "--tiers":
            tiersdict = _parse_tier_channels(a)
        elif o == "--dict-codec":
            dictcodec = a
        elif o == "--out-codec":
            outcodec = a
        elif o == "--in-codec":
            incodec = a
        elif o == "--input":
            textgrid = a
        elif o == "--output":
            outfile = a
        elif o == "--no-check-dict":
            check_dict = False
        elif o == "--input-type":
            input_type = a
        elif o == "--text2tg":
            text2tg_file = a
        elif o == "--align-cmd":
            align_cmd = a
        elif o in ("-h", "--help"):
            help()
            sys.exit(0)
    # Any --input-type value other than 'text' or 'raw' means textgrid input.
    is_text_input = input_type in ('text', 'raw')
    is_raw_input = input_type == 'raw'

    if not args:
        print('\n[ERROR]: input audio file not specified')
        usage()
        sys.exit(2)
    wav = args[0]
    wav_base = os.path.splitext(wav)[0]

    if textgrid is None:
        if is_raw_input:
            # With raw input, --input carries the transcript itself, so
            # there is no filename to fall back on.
            print('\n[ERROR]: --input-type raw requires --input')
            usage()
            sys.exit(2)
        # Default input file is named after the audio file.
        textgrid = wav_base + ('.txt' if is_text_input else '.TextGrid')
    if outfile is None:
        if is_raw_input:
            # --input is not a filename here; name the output after the audio.
            outfile = wav_base + '.multi_align.TextGrid'
        else:
            bname, ext = os.path.splitext(textgrid)
            if is_text_input:
                outfile = bname + '.multi_align.TextGrid'
            else:
                outfile = bname + '.multi_align' + ext
    if text2tg_file is not None and not is_text_input:
        print('\n[ERROR]: --text2tg requires --input-type text|raw')
        usage()
        sys.exit(2)

    if not is_text_input:
        inm = audiolabel.LabelManager(
            from_file=textgrid,
            from_type='praat',
            codec=incodec
        )
    else:
        if incodec is None:
            incodec = 'utf-8'
        if is_raw_input:
            transcript = textgrid
        else:
            with open(textgrid, 'r', encoding=incodec) as infile:
                transcript = infile.read()
        # Wrap the transcript in a single label that spans the whole audio
        # file, in a single tier named 'speaker'.
        dur = duration(wav)
        tier = audiolabel.IntervalTier(end=dur, name='speaker')
        tier.add(
            audiolabel.Label(text=transcript, t1=0.0, t2=dur, codec=incodec)
        )
        inm = audiolabel.LabelManager(
            codec=incodec
        )
        inm.add(tier)
    if dictcodec is None:
        dictcodec = inm.codec
    if outcodec is None:
        outcodec = inm.codec
    outm = audiolabel.LabelManager(codec=outcodec)
    if not tiersdict:     # Default to all tiers
        tiers = [t for t in inm]
    else:
        tiers = [inm.tier(name) for name in tiersdict]

    if check_dict:
        missing = find_missing_dictwords(
            tiers, pattern, main_dict, local_dict, dictcodec
        )
        if missing is not None:
            sys.stderr.write('Missing words found:\n')
            for word in missing:
                sys.stderr.write('{:}\n'.format(word))
            sys.stderr.write(
                'Quitting. Use --no-check-dict to skip the dictionary check.\n'
            )
            sys.exit(3)

    for tier in tiers:
        # Channel 1 is used for tiers without a channel from --tiers.
        channel = tiersdict.get(tier.name, 1)
        phonetier, wordtier = multi_align_tier(
            tier,
            wav,
            channel,
            pattern=pattern,
            dictcodec=dictcodec,
            align_cmd=align_cmd
        )
        if len(tiers) > 1:
            # Disambiguate output tier names when several tiers are aligned.
            phonetier.name = '{:}_phone'.format(tier.name)
            wordtier.name = '{:}_word'.format(tier.name)
        outm.add(phonetier)
        outm.add(wordtier)
    with open(outfile,'wb') as outf:
        outf.write(outm.as_string('praat_short').encode(outcodec))

    if text2tg_file is not None:
        with open(text2tg_file,'wb') as outf:
            outf.write(inm.as_string('praat_short').encode(outcodec))

if __name__ == '__main__':
    main()