#!/usr/bin/env python
# coding=UTF-8
# CWA Dump File Recovery
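#
# Recovers CWA-format data from a raw dump image: scans every sector for 'MD' header and
# 'AX' data blocks, optionally repairs them, and writes a reconstructed .CWA file.
#
# Command line: [options] [input-file [output-file]]
#   defaults: input "cwa-dump.img", output "cwa-recover.cwa" (the output must not already exist)
#
# Options (the method letters are the sort-key order: s=session, q=sequence, t=timestamp):
#   --method-sqt   sort by session, sequence, timestamp (default; sessions unique, sequence does not reset)
#   --method-tsq   sort by timestamp, session, sequence (clock not reset; sessions need not be unique)
#   --method-stq   sort by session, timestamp, sequence (clock may be reset; sessions unique)
#   --method-none  do not reorder the recovered data blocks
#   --modify N     bit-mask of repairs to apply (add values to combine):
#                    1  resequence non-consecutive sequence IDs while writing
#                    2  rewrite data-block session IDs that do not match the header's
#                    4  correct probable single-bit-corrupted session IDs in incomplete blocks
#                    8  correct probable single-bit-corrupted sequence IDs in incomplete blocks
#                   16  recalculate checksums that were invalid in the input
#                   32  write zero instead of a recalculated checksum
#
# Example invocation (file names are illustrative):
#   python cwa-recover.py --modify 16 cwa-dump.img cwa-recover.cwa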

import sys
import os
from struct import pack_into, unpack

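# 16-bit word-wise checksum of a 512-byte sector: sums the 256 little-endian 16-bit words
# modulo 0x10000.  A valid sector (including its stored checksum word) sums to zero.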
def checksum512(data):
  total = 0
  for o in range(256):
    value = data[2 * o] | (data[2 * o + 1] << 8)
    total = (total + value) & 0xffff
  return total

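# Returns the value if exactly one bit is set (i.e. it is a power of two), otherwise 0.
# Used to detect and correct probable single-bit corruption in session/sequence IDs.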
def singleBit(value):
  if value != 0 and not (value & (value - 1)):
    return value
  else:
    return 0

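# Packed CWA timestamp layout (32 bits, most- to least-significant):
#   6-bit year (since 2000), 4-bit month, 5-bit day, 5-bit hours, 6-bit minutes, 6-bit seconds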
def _fast_timestamp(value):
    """Faster date/time parsing.  This does not include 'always' limits; invalid dates do not cause an error; the first call will be slower as a lookup table is created."""
    # On first run, build lookup table for initial 10-bits of the packed date-time parsing, minus one day as days are 1-indexed
    if not hasattr(_fast_timestamp, "SECONDS_BEFORE_YEAR_MONTH"):
        _fast_timestamp.SECONDS_BEFORE_YEAR_MONTH = [0] * 1024        # YYYYYYMM MM (Y=years since 2000, M=month-of-year 1-indexed)
        SECONDS_PER_DAY = 24 * 60 * 60
        DAYS_IN_MONTH = [ 0, 31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31, 0, 0, 0 ]    # invalid month 0, months 1-12 (non-leap-year), invalid months 13-15
        seconds_before = 946684800      # Seconds from UNIX epoch (1970) until device epoch (2000)
        for year in range(0, 64):       # 2000-2063
            for month in range(0, 16):  # invalid month 0, months 1-12, invalid months 13-15
                index = (year << 4) + month
                _fast_timestamp.SECONDS_BEFORE_YEAR_MONTH[index] = seconds_before - SECONDS_PER_DAY    # minus one day as day-of-month is 1-based
                days = DAYS_IN_MONTH[month]
                if year % 4 == 0 and month == 2:    # Correct for this year range (2000 was a leap year, despite being a multiple of 100, as it is a multiple of 400)
                    days += 1
                seconds_before += days * SECONDS_PER_DAY

    year_month = (value >> 22) & 0x3ff
    day   = (value >> 17) & 0x1f
    hours = (value >> 12) & 0x1f
    mins  = (value >>  6) & 0x3f
    secs  = value & 0x3f
    return _fast_timestamp.SECONDS_BEFORE_YEAR_MONTH[year_month] + ((day * 24 + hours) * 60 + mins) * 60 + secs


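# Recovery runs in two passes over the dump image:
#   1. Scan each 512-byte sector (allowing up to paddingMax leading garbage bytes) for 'MD'
#      header blocks and 'AX' data blocks, recording offsets, session IDs, sequence IDs and timestamps.
#   2. Optionally sort the data blocks by the 'method' key order, apply the repairs requested via
#      modifyFlags, and write the header block followed by the data sectors to the output file.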
def recoverCwa(inputFile, outputFile, method, modifyFlags):
  initialOffset = 0 # 0x20000 in drive dumps
  sectorSize = 512
  headerSize = 2 * sectorSize
  paddingMax = 128  # Allow first few bytes to be garbage
  # allowedPadding = [ 0xe0, 0xf0 ] # Allowed garbage bytes
  globalSessionId = None
  globalDeviceId = None

  idxTimestamp = method.index('t')
  idxSession = method.index('s')
  idxSequenceId = method.index('q')

  reorder = True
  if method.find('+') != -1:
    reorder = False
  
  print("Reading input: ", inputFile)
  
  with open(inputFile, 'rb') as fi:
    fi.seek(0, os.SEEK_END)
    fileSize = fi.tell()
    fi.seek(0, os.SEEK_SET)
    fileData = fi.read(fileSize)    # read the entire dump image into memory
  
  # Store found blocks
  metadata = [] # (fileOffset, blockLength, sessionId)
  data = []     # (key0, key1, key2, fileOffset, blockLength, checksumOk) -- first three keys ordered per 'method'
  
  countCorrectedSession = 0
  countCorrectedSequence = 0
  countChecksumErrors = 0
  maxChecksumErrors = 10000
  
  numSectors = fileSize // sectorSize
  print("Processing " + str(numSectors) + " sectors...")
  lastPerc = -1
  for i in range(numSectors):
    block = fileData[i * sectorSize:(i + 1) * sectorSize]
    perc = (100 * i) // numSectors
    if 5 * (perc // 5) != lastPerc:
      print("..." + str(i) + "/" + str(numSectors) + " = " + str(perc) + "%...")
      lastPerc = perc
      
    # Quickly skip sectors that begin with an erased (0xff) byte
    if block[0] == 255:
      continue
    
    for o in range(paddingMax):
      ofs = i * sectorSize + o
      # Check if this is a data block
      if block[o] == ord('A') and block[o + 1] == ord('X') and block[o + 2] == 0xfc and block[o+3] == 0x01:
        completeBlock = (o == 0)
        fileOffset = i * sectorSize + o
        blockLength = sectorSize - o
        sessionId = unpack('<I', block[o+6:o+10])[0]
        
        rateCode = unpack('B', block[o+24:o+25])[0]               # @24  +1   Sample rate code, frequency (3200/(1<<(15-(rate & 0x0f)))) Hz, range (+/-g) (16 >> (rate >> 6)).
        frequency = 3200 / (1 << (15 - (rateCode & 0x0f)))

        timestamp = _fast_timestamp(unpack('<I', block[o+14:o+18])[0])
        deviceFractional = unpack('<H', block[o+4:o+6])[0]        # @4   +2   Top bit set: bottom 15 bits are a fraction of a second for the timestamp
        timestampOffset = unpack('<h', block[o+26:o+28])[0]       # @26  +2   Relative sample index from the start of the buffer where the whole-second timestamp is valid

        timeFractional = 0
        if deviceFractional & 0x8000:
          # Need to undo the backwards-compatible shim by calculating how many whole samples the fractional part of the timestamp accounts for.
          timeFractional = (deviceFractional & 0x7fff) << 1       # use the bottom 15 bits as a 16-bit fractional time
          timestampOffset += (timeFractional * int(frequency)) >> 16  # undo the backwards-compatible shift (as we have a true fractional)

        # Add fractional time to timestamp
        timestamp += timeFractional / 65536
                    
        sequenceId = unpack('<I', block[o+10:o+14])[0]
        
        # Verify the checksum (only possible for complete blocks); a valid sector's 16-bit words sum to zero
        checksumOk = False
        if completeBlock:
          storedChecksum = unpack('<H', block[o+510:o+512])[0]
          actualChecksum = checksum512(block)
          if actualChecksum == 0:
            checksumOk = True
          else:
            countChecksumErrors += 1
            if countChecksumErrors < maxChecksumErrors:
              print("Mismatched checksum #" + str(countChecksumErrors) + " at " + str(i) + " = " + str(actualChecksum) + " (" + str(storedChecksum) + ")")
              if countChecksumErrors >= maxChecksumErrors - 1:
                print("NOTE: No more checksum errors will be reported!")

        # For an incomplete block
        if not completeBlock:
          # Check the session ID against the one from the header (if a header has been seen)
          if globalSessionId is not None and sessionId != globalSessionId and (modifyFlags & 4) != 0:
            x = sessionId ^ globalSessionId
            bitValue = singleBit(x)
            if bitValue:
              print("Corrected probable corrupted session ID: " + str(sessionId) + " (" + str(globalSessionId) + ") ^" + hex(x))
              sessionId ^= bitValue
              countCorrectedSession += 1
            else:
              print("Uncorrected mismatched session ID: " + str(sessionId) + " (" + str(globalSessionId) + ") ^" + hex(x))
              
          # Check the sequence ID against the expected next value
          nextSequence = None
          if len(data) > 0:
            lastItem = data[len(data) - 1]
            nextSequence = lastItem[idxSequenceId] + 1
          if nextSequence is not None and sequenceId != nextSequence and (modifyFlags & 8) != 0:
            x = sequenceId ^ nextSequence
            bitValue = singleBit(x)
            if bitValue:
              print("Corrected possible corrupted sequence ID: " + str(sequenceId) + " (" + str(nextSequence) + ") ^" + hex(x))
              sequenceId ^= bitValue
              countCorrectedSequence += 1
            else:
              # print("NOTE: Non-consecutive sequence ID: " + str(sequenceId) + " (" + str(nextSequence) + ") ^" + hex(x))
              pass
        
        item = [0, 0, 0, fileOffset, blockLength, checksumOk]
        item[idxSession] = sessionId
        item[idxTimestamp] = timestamp
        item[idxSequenceId] = sequenceId
        
        data.append(tuple(item))
        break
      # Check if this is a header block
      if block[o] == ord('M') and block[o + 1] == ord('D') and block[o + 2] == 0xfc and block[o+3] == 0x03:
        fileOffset = i * sectorSize + o
        blockLength = headerSize - o
        deviceId = unpack('<H', block[o+5:o+7])[0]			# @ 5  +2   Device identifier
        sessionId = unpack('<I', block[o+7:o+11])[0]
        deviceIdUpper = unpack('<H', block[o+11:o+13])[0]		# @ 11  +2   Upper device identifier
        if deviceIdUpper == 0xffff:
          deviceIdUpper = 0x0000
        print("Found header with session ID: " + str(sessionId) + " (device=" + str(deviceId) + ")")
        globalDeviceId = (deviceIdUpper << 16) + deviceId
        globalSessionId = sessionId
        metadata.append((fileOffset, blockLength, sessionId))
        break
  
  print("Found: " + str(len(metadata)) + " raw metadata block(s), " + str(len(data)) + " raw data block(s)")
  print("Corrected: " + str(countCorrectedSession) + " session IDs, and " + str(countCorrectedSequence) + " sequences")
  
  if reorder:
    print("Sorting data blocks...")
    data = sorted(data)
  
  print("Writing output: ", outputFile)
  with open(outputFile, 'wb') as fo:
    if len(metadata) > 0:
      if len(metadata) > 1:
        print("WARNING: Multiple header blocks found, the first one will be used (some readers may not parse the whole data if the data block session-id does not match the one in the header")
      fileOffset = metadata[0][0]
      blockLength = metadata[0][1]
      globalSessionId = metadata[0][2]
      print("Using session ID: " + str(globalSessionId))
      block = bytearray(fileData[fileOffset:fileOffset + headerSize])  # blockLength
      if blockLength < headerSize:
        # Pad block with missing bytes
        missingBytes = headerSize - blockLength
        for o in range(missingBytes):
          block[blockLength + o] = 0xff
      fo.write(block)
    else:
      print("WARNING: No header block found, this utility does not yet create one, so the output file will not be valid")
    
    totalSamples = 0
    numOffsetSectors = 0
    totalMissingSamples = 0 # from offsets (not skipped sectors)
    numSkippedSectors = 0
    numDuplicates = 0
    numOtherSession = 0
    numOutOfSequence = 0
    numBackwards = 0
    numWritten = 0
    lastData = None
    lastPerc = -1
    for i in range(len(data)):
      # Progress
      perc = (100 * i) // len(data)
      if 5 * (perc // 5) != lastPerc:
        print("..." + str(i) + "/" + str(len(data)) + " = " + str(perc) + "%...")
        lastPerc = perc

      sessionId = data[i][idxSession]
      timestamp = data[i][idxTimestamp]
      sequenceId = data[i][idxSequenceId]
      fileOffset = data[i][3]
      fileSector = data[i][3] // sectorSize
      blockLength = data[i][4]
      checksumOk = data[i][5]
      if lastData != None:
        prevSessionId = lastData[idxSession]
        prevTimestamp = lastData[idxTimestamp]
        prevSequenceId = lastData[idxSequenceId]
        prevFileSector = lastData[3] // sectorSize
      else:
        prevSessionId = None
        prevTimestamp = None
        prevSequenceId = None
        prevFileSector = None
      recalculateChecksum = False

      if (modifyFlags & 16) != 0 and not checksumOk:
        recalculateChecksum = True
      
      print("Timestamp: " + str(timestamp))
      
      block = bytearray(fileData[fileOffset:fileOffset + sectorSize]) # blockLength
      
      # Patch-in possibly updated values
      if (modifyFlags & 12) != 0:
        pack_into('<I', block, 6, sessionId)
        pack_into('<I', block, 10, sequenceId)
      
      # Resequence
      if (modifyFlags & 1) != 0:
        if prevSequenceId is not None and sequenceId != prevSequenceId + 1:
          nextSequence = prevSequenceId + 1
          print("Resequenced: " + str(sequenceId) + " -> " + str(nextSequence))
          sequenceId = nextSequence
          pack_into('<I', block, 10, sequenceId)
          recalculateChecksum = True
      
      # Trace
      if False:
        print("#" + str(i) + " session=" + str(sessionId) + " t=" + str(timestamp) +  " sequence=" + str(sequenceId) + " @" + str(fileOffset) + "+" + str(blockLength) + " ")
      
      if prevFileSector is not None and fileSector != prevFileSector + 1:
        jump = fileSector - prevFileSector
        # print("NOTE: #" + str(i) + " Non-consecutive sectors (advanced by " + str(jump) + "), next was " + str(prevFileSector+1) + " at " + str((prevFileSector+1) * sectorSize) + "")
      
      if lastData != None and timestamp < prevTimestamp:
        print("WARNING: #" + str(i) + " Found a jump back in time: " + str(prevTimestamp) + " to " + str(timestamp))
        numBackwards += 1
        
      if lastData != None and sessionId == prevSessionId and sequenceId == prevSequenceId and timestamp == prevTimestamp:
        # print("WARNING: #" + str(i) + " Ignoring a duplicate sector: session=" + str(sessionId) + " sequence=" + str(sequenceId) + " timestamp=" + str(timestamp))
        numDuplicates += 1
        continue
        
      if lastData != None and sessionId == prevSessionId and sequenceId != prevSequenceId + 1:
        missing = sequenceId - prevSequenceId - 1
        if missing > 0 and missing < 10:
          print("WARNING: #" + str(i) + " Missing " + str(missing) + " sectors, skipped from sequence id: " + str(prevSequenceId) + " to " + str(sequenceId))
          numSkippedSectors += missing
        else:
          print("WARNING: #" + str(i) + " Found a non-consecutive sequence ID in session " + str(sessionId) + ": " + str(prevSequenceId) + " to " + str(sequenceId))
          numOutOfSequence += 1
        
      if globalSessionId is not None and sessionId != globalSessionId:
        print("WARNING: #" + str(i) + " Mismatched session ID: " + str(sessionId) + " but header is " + str(globalSessionId) + " file offset " + str(fileOffset))
        numOtherSession += 1
        # Rewrite the session ID to match the header's
        if (modifyFlags & 2) != 0:
          pack_into('<I', block, 6, globalSessionId)
          recalculateChecksum = True
      
      missingBytes = 0
      if blockLength < sectorSize:
        # Pad block with missing bytes
        missingBytes = sectorSize - blockLength
        print("NOTE: Missing bytes: " + str(missingBytes))
        for o in range(missingBytes):
          block[blockLength + o] = 0
      
      # Determine number of bytes per sample
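      # numAxesBPS: top nibble = number of axes/channels, bottom nibble = bytes per axis;
      # a bottom nibble of 0 with 3 channels indicates the packed mode (3 x 10-bit values in 4 bytes per sample)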
      numAxesBPS = block[25]
      channels = (numAxesBPS >> 4) & 0x0f
      bytesPerAxis = numAxesBPS & 0x0f
      bytesPerSample = 0
      if bytesPerAxis == 0 and channels == 3:
        bytesPerSample = 4
      elif bytesPerAxis > 0 and channels > 0:
        bytesPerSample = bytesPerAxis * channels
      
      if missingBytes > 0:
        # Calculate number of missing samples
        missingSamples = 0
        if missingBytes > 2 and bytesPerSample > 0: # more than just the 2-byte checksum is missing
          missingSamples = (missingBytes - 2 + bytesPerSample - 1) // bytesPerSample

        print("NOTE: Missing samples: " + str(missingSamples) + " -- " + str(bytesPerSample) + " bytes-per-sample")

        # Are there any missing samples?
        if missingSamples > 0:
          totalMissingSamples += missingSamples
          numOffsetSectors += 1
          sampleOffset = 30 + 480 - (missingSamples * bytesPerSample)
          # TODO: Fill in mean of samples?

        # The checksum is invalid if any bytes were missing
        block[510] = 0x00
        block[511] = 0x00
        recalculateChecksum = True
      
      # Recalculate checksum
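      # The stored checksum word is chosen so that the whole sector's 16-bit words sum to zero:
      # zero the checksum field, sum the words, then store the two's complement of that sum.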
      if recalculateChecksum:
        finalChecksum = 0x0000
        # Flag can be set to set recomputed checksums to zero
        if (modifyFlags & 32) == 0:
          pack_into('<H', block, 510, 0)
          checksumAtZero = checksum512(block)
          finalChecksum = (~checksumAtZero + 1) & 0xffff
        pack_into('<H', block, 510, finalChecksum)

      fo.write(block)
      lastData = data[i]
      numWritten += 1
      if bytesPerSample != 0:
        totalSamples += 480 // bytesPerSample

    print("Total input checksum errors: " + str(countChecksumErrors))
    print("Total input sectors: " + str(numSectors))
    print("Wrote " + str(numWritten) + " sectors")
    print("Duplicates: " + str(numDuplicates))
    print("Out-of-sequence: " + str(numOutOfSequence))
    print("Backward jumps: " + str(numBackwards))
    print("Missing sectors: " + str(numSkippedSectors))
    print("Mismatching Session ID: " + str(numOtherSession))
    print("Number of sectors with initial damage: " + str(numOffsetSectors))
    print("Total number of missing samples from initial damage: " + str(totalMissingSamples))
    print("Total samples: " + str(totalSamples))
    if globalSessionId is None:
      print("WARNING: No header block was written (output file will not be valid)")


def main():
  print("Running...")

  method = 'sqt'
  inputFile = None
  outputFile = None
  modifyFlags = 0
  arg = 1
  while arg < len(sys.argv):
    if sys.argv[arg].startswith("-"):
      if sys.argv[arg] == "--method-sqt":
        # 'sqt' - sessions must be unique and sequence may not reset
        method = "sqt"
      elif sys.argv[arg] == "--method-tsq":
        # 'tsq' - clock not reset, sessions don't have to be unique, sequence may reset
        method = "tsq"
      elif sys.argv[arg] == "--method-stq":
        # 'stq' - clock may be reset, sessions were unique, sequence may reset
        method = "stq"
      elif sys.argv[arg] == "--modify":
        arg += 1
        modifyFlags = int(sys.argv[arg])
      elif sys.argv[arg] == "--method-none":
        # Special method to disable reordering
        method = "tsq+"
      else:
        print("ERROR: Unrecognized option: " + sys.argv[arg])
        return
    elif inputFile is None:
      inputFile = sys.argv[arg]
    elif outputFile is None:
      outputFile = sys.argv[arg]
    else:
      print("ERROR: Unrecognized positional argument: " + sys.argv[arg])
      return
    arg += 1

  if inputFile is None:
    inputFile = "cwa-dump.img"
  if outputFile is None:
    outputFile = "cwa-recover.cwa"

  print("NOTE: Using input file:", inputFile)
  if not os.path.exists(inputFile):
    print("ERROR: Input file does not exist:", inputFile)
    return
  
  print("NOTE: Using output file:", outputFile)
  if os.path.exists(outputFile):
    print("ERROR: Output file already exists, must remove or use another output file:", outputFile)
    return
  
  return recoverCwa(inputFile, outputFile, method, modifyFlags)

if __name__ == "__main__":
  main()