#!/usr/bin/env python3 import array import argparse import bz2 import concurrent.futures import collections import csv import ctypes import ftplib import functools import glob import gzip import hashlib import http.client import inspect import io import itertools import json import logging import logging.handlers import lzma import math import multiprocessing import os import pty import random import re import shutil import signal import subprocess import sys import tarfile import time import tempfile import threading import traceback import urllib import urllib.error import urllib.parse import urllib.request import zipfile import zlib LOG = None SCRIPT_PATHNAME = None # this is a placeholder value. The real version will # be substituted once after calling `install_kraken.sh`. SCRIPT_VERSION = "#####=VERSION=#####" NCBI_REST_API = "api.ncbi.nlm.nih.gov" NCBI_SERVER = "ftp.ncbi.nlm.nih.gov" GREENGENES_SERVER = "greengenes.microbio.me" SILVA_SERVER = "ftp.arb-silva.de" # GTDB_SERVER = "data.gtdb.ecogenomic.org" GTDB_SERVER = "data.ace.uq.edu.au" AMBIGUOUS_TAXID = 2 ** 32 - 1 WRAPPER_ARGS_TO_BIN_ARGS = { "block_size": "-B", "classified_out": "-C", "confidence": "-T", "fast_build": "-F", "interleaved": "-S", "kmer_len": "-k", "max_db_size": "-M", "memory_mapping": "-M", "minimizer_len": "-l", "minimum_bits_for_taxid": "-r", "minimum_base_quality": "-Q", "minimum_hit_groups": "-g", "output": "-O", "paired": "-P", "protein": "-X", "quick": "-q", "report": "-R", "report_minimizer_data": "-K", "report_zero_counts": "-z", "skip_counts": "-s", "sub_block_size": "-b", "threads": "-p", "unclassified_out": "-U", "use_mpa_style": "-m", "use_names": "-n", "use_daemon": "-D", } class FTP: def __init__(self, server): self.ftp = ftplib.FTP(server, timeout=600) self.ftp.login() self.ftp.sendcmd("TYPE I") self.pwd = "/" self.server = server def _progress_bar(self, f, remote_size): pb = ProgressBar(remote_size, f.tell()) def inner(block): nonlocal f, remote_size, pb written = 0 while written < len(block): written += f.write(block[written:]) size_on_disk = f.tell() pb.progress(size_on_disk) LOG.debug( "{:s} {: >10s}\r".format( pb.get_bar(), format_bytes(size_on_disk) ) ) return inner def download(self, remote_dir, filepaths): if isinstance(filepaths, str): filepaths = [filepaths] number_of_files = len(filepaths) self.cwd(remote_dir) for index, filepath in enumerate(filepaths): mode = "ab" local_size = 0 remote_size = self.size(filepath) if os.path.exists(filepath): local_size = os.stat(filepath).st_size else: if os.path.basename(filepath) != filepath: os.makedirs(os.path.dirname(filepath), exist_ok=True) if local_size == remote_size: LOG.info( "Already downloaded {:s}\n".format(get_abs_path(filepath)) ) continue if local_size > remote_size: mode = "wb" url_components = urllib.parse.SplitResult( "ftp", self.server, os.path.join(remote_dir, filepath), "", "" ) url = urllib.parse.urlunsplit(url_components) if number_of_files == 1: LOG.info("Downloading {:s}\n".format(url)) else: LOG.info( "[{:d}/{:d}] Downloading {:s}\n".format( index + 1, number_of_files, url ) ) with open(filepath, mode) as f: while True: try: cb = self._progress_bar(f, remote_size) self.ftp.retrbinary( "RETR " + filepath, cb, rest=f.tell() ) break except KeyboardInterrupt: f.flush() self.close() sys.exit(1) except ftplib.all_errors: f.flush() self.reconnect() self.cwd(remote_dir) continue absolute_path = get_abs_path(filepath) local_filename, local_dirname = os.path.basename( absolute_path ), os.path.dirname(absolute_path) clear_console_line() 
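# --- Illustrative sketch (not part of the wrapper) --------------------------
# The retry loop above resumes an interrupted transfer by opening the local
# file in append mode and handing its current offset to RETR via the `rest`
# argument.  A minimal, standalone version of that idea is sketched below;
# the host name and paths are hypothetical placeholders.
import ftplib
import os


def resume_ftp_download(host, remote_path, local_path):
    """Fetch remote_path from host, resuming from any partial local file."""
    offset = os.path.getsize(local_path) if os.path.exists(local_path) else 0
    with ftplib.FTP(host, timeout=600) as ftp, open(local_path, "ab") as out:
        ftp.login()
        ftp.sendcmd("TYPE I")               # binary mode, needed for SIZE/REST
        if offset == ftp.size(remote_path):
            return                          # already fully downloaded
        ftp.retrbinary("RETR " + remote_path, out.write, rest=offset)


# Example (not executed here):
# resume_ftp_download("ftp.example.org", "pub/data/file.fna.gz", "file.fna.gz")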
LOG.info( "Saved {:s} to {:s}\n".format(local_filename, local_dirname) ) def cwd(self, remote_pathname): self.ftp.cwd(remote_pathname) self.pwd = remote_pathname def size(self, filepath): size = 0 while True: try: size = self.ftp.size(filepath) break except ftplib.error_temp: self.reconnect() continue return size def exists(self, filepath): while True: try: self.size(filepath) break except ftplib.error_perm as e: if e.args[0].find("No such file or directory"): return False raise return True def connect(self, server): self.ftp = ftplib.FTP(server) self.ftp.login() self.ftp.sendcmd("TYPE I") def reconnect(self): host = self.ftp.host self.ftp.close() self.connect(host) self.ftp.cwd(self.pwd) def host(self): return self.ftp.host def close(self): self.ftp.quit() class ProgressBar: def __init__(self, stop, current=0, width=30): self.stop = stop self.width = width self.current = current self.bar = list("-" * self.width) self.step = stop / self.width self.last_index = self._calculate_index() if self.current > 0: self.progress() def progress(self, amount=0, relative=False): if relative: self.current += amount else: self.current = amount if self.current > self.stop: self.current = self.stop index = self._calculate_index() for i in range(self.last_index, index): if i == 0: self.bar[i] = ">" else: self.bar[i - 1], self.bar[i] = "=", ">" self.last_index = index def get_bar(self): percentage = int(self.current / self.stop * 100) return "{:3d}% {:s}".format(percentage, "[" + "".join(self.bar) + "]") def _calculate_index(self): return math.floor(self.current / self.step) class NCBI_URI_Builder: def __init__(self, endpoint="genome", *path_components): path_components = list(path_components) for i, component in enumerate(path_components): if isinstance(component, list): component = ",".join(component) path_components[i] = urllib.parse.quote(component) self.filters = {} self.path = "/datasets/v2/{}/{}".format(endpoint, "/".join(path_components)) def assembly_source(self, source=None): if source: self.filters["filters.assembly_source"] = urllib.parse.quote(source) return self def assembly_levels(self, levels): if levels: self.filters["filters.assembly_level"] = levels return self def assembly_version(self, version=None): if version: self.filters["filters.assembly_version"] = version return self def exclude_paired_reports(self, exclude_pairs=False): if exclude_pairs: self.filters["filters.exclude_paired_reports"] = "true" return self def has_annotation(self, annotated=False): if annotated: self.filters["filters.has_annotation"] = "true" return self def search_text(self, text=None): if text: self.filters["filters.search_text"] = urllib.parse.quote(text) return self def reference_only(self, reference_only=False): if reference_only: self.filters["filters.reference_only"] = reference_only return self def page_size(self, size=None): if size: self.filters["page_size"] = size return self def page_token(self, token=None): if token: self.filters["page_token"] = token return self def include_annotation_type(self, annotation_type=None): if annotation_type: self.filters["include_annotation_type"] = annotation_type return self def set_filters_from_args(self, args): for k, v in vars(args).items(): if hasattr(self, k): self = getattr(self, k)(v) def build(self): filters = [] for k, v in self.filters.items(): if isinstance(v, list): for value in v: filters.append("{}={}".format(k, value)) else: filters.append("{}={}".format(k, v)) query = "&".join(filters) split = urllib.parse.SplitResult( scheme="", netloc="", 
path=self.path, query=query, fragment="" ) return urllib.parse.urlunsplit(split) def reset(self): self.filters.clear() def wrap_with_globals(f, log_queue, log_level, script_pathname, *args): global LOG global SCRIPT_PATHNAME LOG = Logger.setup_queue_logger(log_queue, log_level) SCRIPT_PATHNAME = script_pathname return f(*args) def clear_console_line(): LOG.debug("\33[2K\r") def count_lines(*filenames): lines = 0 for fname in filenames: with open(fname, "r") as f: for line in f: lines += 1 return lines def dwk2(): estimate_capacity = find_kraken2_binary("estimate_capacity") output = subprocess.check_output( [estimate_capacity, "-h"], stderr=subprocess.STDOUT ) for line in output.split(b"\n"): if line.startswith(b"Usage:"): return True if line.strip().endswith(b"") else False return False def get_binary_options(binary_pathname): options = [] proc = subprocess.Popen(binary_pathname, stderr=subprocess.PIPE) lines = proc.stderr.readlines() for line in lines: match = re.search(rb"\s(-\w)\s", line) if not match: continue options.append(match.group(1).decode()) return options def construct_seed_template(args): if int(args.minimizer_len / 4) < args.minimizer_spaces: LOG.error( "Number of minimizer spaces, {}, exceeds max for " "minimizer length, {}; max: {}\n".format( args.minimizer_spaces, args.minimizer_len, int(args.minimizer_len / 4), ) ) sys.exit(1) return ( "1" * (args.minimizer_len - 2 * args.minimizer_spaces) + "01" * args.minimizer_spaces ) def copy_globals(queue, level, script_pathname): global LOG global SCRIPT_PATHNAME LOG = Logger.setup_queue_logger(queue, level) SCRIPT_PATHNAME = script_pathname def future_raised_exception(future): return future.done() and future.result() is None def url_join(netloc, scheme="https", path="", query="", fragment=""): split_result = urllib.parse.SplitResult( scheme, netloc, path, query, fragment ) return urllib.parse.urlunsplit(split_result) def execute_in_process_pool(func, num_processes, *args): pass def download_and_process_blast_volumes(args): download_files_from_manifest( NCBI_SERVER, args.threads, resume=args.resume ) extraction_futures = [] tarballs_and_converted_volumes = [] with concurrent.futures.ProcessPoolExecutor( max_workers=1 ) as pool: with open("manifest.txt", "r") as in_file: tarballs = in_file.readlines() f = functools.partial( wrap_with_globals, extract_blast_db_files, LOG.get_queue(), LOG.get_level(), SCRIPT_PATHNAME ) LOG.info( "Extracting index (.nin), header (.nhr), and" " sequence files (.nsq) from tarballs\n" ) for tarball in tarballs: tarball = os.path.abspath(tarball) f = pool.submit(extract_blast_db_files, tarball.strip()) extraction_futures.append(f) for future in concurrent.futures.as_completed(extraction_futures): result = future.result() tarballs_and_converted_volumes.append(result) LOG.info("Finished extracting files from {}\n".format(tarball)) for tarball, volume in tarballs_and_converted_volumes: LOG.info("Converting BLAST volume {} to FASTA\n".format(volume)) convert_blast_to_fasta(args, volume, tarball) LOG.info( "Finished converting BLAST volume {} to FASTA\n" .format(volume) ) library_extension = ".faa" if args.protein else ".fna" library_filename = "library" + library_extension LOG.info("Generating {} from converted volumes\n".format(library_filename)) with open(library_filename, "w") as lib, \ open("prelim_map.txt", "w") as plm: for _, volume in sorted(tarballs_and_converted_volumes): if not os.path.exists(volume + library_extension): LOG.error( "Missing volume: {}, " .format(volume + library_extension) ) 
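# --- Illustrative sketch (not part of the wrapper) --------------------------
# construct_seed_template() above builds the spaced-seed string handed to the
# build binary: (minimizer_len - 2 * minimizer_spaces) copies of "1" followed
# by "01" repeated minimizer_spaces times, after checking that
# minimizer_spaces <= minimizer_len // 4.  A quick standalone check of that
# arithmetic (e.g. with l=31 and s=7):
def seed_template(minimizer_len, minimizer_spaces):
    if minimizer_spaces > minimizer_len // 4:
        raise ValueError("too many minimizer spaces for this minimizer length")
    return "1" * (minimizer_len - 2 * minimizer_spaces) + "01" * minimizer_spaces


assert seed_template(31, 7) == "1" * 17 + "01" * 7
assert len(seed_template(31, 7)) == 31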
sys.exit(1) with open(volume + library_extension, "r") as in_file: shutil.copyfileobj(in_file, lib) with open(volume + "_prelim_map.txt", "r") as in_file: shutil.copyfileobj(in_file, plm) def create_manifest_for_blast_db(db_name, volume_numbers, protein=False): suffix = "-prot-metadata.json" if protein else "-nucl-metadata.json" json_filename = "blast/db/" + db_name + suffix http_download_file2(NCBI_SERVER, [json_filename]) json_filename = os.path.abspath(json_filename) with open(json_filename, "r") as in_file: data = json.load(in_file) with open("manifest.txt", "w") as out_file: for volume in data["files"]: match = re.search(r"\.(\d+)\.", volume) if match: volume_number = match.group(1) if int(volume_number) not in volume_numbers: continue path = urllib.parse.urlsplit(volume).path out_file.write(path[1:] + "\n") def extract_blast_db_files(tarball_pathname): extract_dirname = os.path.dirname(tarball_pathname) volume = None with tarfile.open(tarball_pathname, "r:gz") as tar: for member in tar.getnames(): if member.endswith(("nsq", "nin", "nhr")): if not volume: volume = os.path.splitext(member)[0] tar.extract(member, extract_dirname) return ( tarball_pathname, os.path.join(extract_dirname, volume) ) def convert_blast_to_fasta(args, volume, tarball): extension = ".faa" if args.protein else ".fna" volume_dirname = os.path.dirname(volume) volume_basename = os.path.basename(volume) tarball = os.path.basename(tarball) tmp_fasta_filename = volume + extension + ".tmp" fasta_filename = volume + extension remote_filepath = url_join( NCBI_SERVER, path="blast/db/" + tarball ) blast_to_fasta_bin = find_kraken2_binary("blast_to_fasta") # blast_to_fasta_argv = "" proc = subprocess.Popen([blast_to_fasta_bin, "-s", "-t", volume]) if proc.wait() != 0: LOG.error( "Encountered an error while converting BLAST format to FASTA\n" ) sys.exit(1) with open(fasta_filename, "r") as in_file: prelim_map_name = os.path.join( volume_dirname, volume_basename + "_prelim_map.txt" ) with open(prelim_map_name, "w") as out_file: scan_fasta_file( in_file, out_file, lenient=True, sequence_to_url=remote_filepath ) if not args.no_masking: shutil.move(fasta_filename, tmp_fasta_filename) mask_files( [tmp_fasta_filename], fasta_filename, args.masker_threads, args.protein ) os.remove(tmp_fasta_filename) def wrapper_args_to_binary_args(opts, argv, binary_args): for k, v in vars(opts).items(): if k not in WRAPPER_ARGS_TO_BIN_ARGS: continue if WRAPPER_ARGS_TO_BIN_ARGS[k] not in binary_args: continue if v is False: continue if v is None: continue if v is True: argv.append(WRAPPER_ARGS_TO_BIN_ARGS[k]) else: argv.extend([WRAPPER_ARGS_TO_BIN_ARGS[k], str(v)]) def find_kraken2_binary(name): # search the OS PATH if "PATH" in os.environ: for dir in os.environ["PATH"].split(":"): if os.path.exists(os.path.join(dir, name)): return os.path.join(dir, name) # search for binary in the same directory as wrapper script_parent_directory = get_parent_directory(SCRIPT_PATHNAME) if os.path.exists(os.path.join(script_parent_directory, name)): return os.path.join(script_parent_directory, name) # if called from within kraken2 project root, search the src dir project_root = get_parent_directory(script_parent_directory) if "src" in os.listdir(project_root) and name in os.listdir( os.path.join(project_root, "src") ): return os.path.join(project_root, os.path.join("src", name)) # not found in these likely places, exit LOG.error("Unable to find {:s}, exiting\n".format(name)) sys.exit(1) def get_parent_directory(pathname): if len(pathname) == 0: return None 
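# --- Illustrative sketch (not part of the wrapper) --------------------------
# wrapper_args_to_binary_args() above translates parsed wrapper options into
# the short flags understood by the underlying binary: True becomes a bare
# flag, other values become a "flag value" pair, and None/False entries are
# skipped.  A standalone version of that translation with a made-up option
# table and namespace:
import argparse

_DEMO_OPT_TABLE = {"threads": "-p", "quick": "-q"}   # hypothetical subset


def to_binary_argv(opts, table=_DEMO_OPT_TABLE):
    argv = []
    for key, value in vars(opts).items():
        flag = table.get(key)
        if flag is None or value is None or value is False:
            continue
        if value is True:
            argv.append(flag)
        else:
            argv.extend([flag, str(value)])
    return argv


assert to_binary_argv(
    argparse.Namespace(threads=8, quick=True, db=None)
) == ["-p", "8", "-q"]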
pathname = os.path.abspath(pathname) if len(pathname) > 1 and pathname[-1] == os.path.sep: return os.path.dirname(pathname[:-1]) return os.path.dirname(pathname) def find_database(database_name): database_path = None if not os.path.isdir(database_name): if "KRAKEN2_DB_PATH" in os.environ: for directory in os.environ["KRAKEN2_DB_PATH"].split(":"): if os.path.exists(os.path.join(directory, database_name)): database_path = os.path.join(directory, database_name) break else: if database_name in os.listdir(os.getcwd()): database_path = database_name else: database_path = os.path.abspath(database_name) if database_path: for db_file in ["taxo.k2d", "hash.k2d", "opts.k2d"]: if not os.path.exists(os.path.join(database_path, db_file)): return None return database_path def remove_files(filepaths, forked=False): total_size = 0 for fname in filepaths: if not os.path.exists(fname): continue elif os.path.isdir(fname): with os.scandir(fname) as iter: directories = [] for entry in iter: if entry.is_dir(): directories.append(entry.path) else: total_size += os.path.getsize(entry.path) LOG.info("Removing {}\n".format(entry.path)) os.remove(entry.path) if not forked and len(directories) >= 4: total_size += remove_files_parallel(directories) else: total_size += remove_files(directories, forked) for directory in directories: shutil.rmtree(directory) else: LOG.info("Removing {}\n".format(fname)) total_size += os.path.getsize(fname) os.remove(fname) return total_size def remove_files_parallel(filepaths): total_size = 0 with concurrent.futures.ProcessPoolExecutor( max_workers=4, ) as pool: futures = [] f = functools.partial( wrap_with_globals, remove_files, LOG.get_queue(), LOG.get_level(), SCRIPT_PATHNAME ) for fname in filepaths: if not os.path.exists(fname): continue future = pool.submit(f, [fname], True) futures.append(future) for future in concurrent.futures.as_completed(futures): total_size += future.result() return total_size def get_taxid_from_seqid(seqid): taxid = None match = re.search(r"(?:^|\|)kraken:taxid\|(\d+)", seqid) if match: taxid = match.group(1) elif re.match(r"^\d+$", seqid): taxid = seqid if not taxid: match = re.search(r"(?:^|\|)([A-Z]+_?[A-Z0-9]+)(?:\||\b|\.)", seqid) if match: taxid = match.group(1) return taxid def hash_string(string): md5 = hashlib.md5() md5.update(string.encode()) return md5.hexdigest() def hash_file(filename, buf_size=8192): LOG.info("Calculating MD5 sum for {}\n".format(filename)) md5 = hashlib.md5() with open(filename, "rb") as in_file: while True: data = in_file.read(buf_size) if not data: break md5.update(data) digest = md5.hexdigest() LOG.info("MD5 sum of {} is {}\n".format(filename, digest)) return digest # This function is part of the Kraken 2 taxonomic sequence # classification system. # # Reads multi-FASTA input and examines each sequence header. Headers are # OK if a taxonomy ID is found (as either the entire sequence ID or as part # of a "kraken:taxid" token), or if something looking like an accession # number is found. Not "OK" headers will are fatal errors unless "lenient" # is used. # # Each sequence header results in a line with three tab-separated values; # the first indicating whether third column is the taxonomy ID ("TAXID") or # an accession number ("ACCNUM") for the sequence ID listed in the second # column. 
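# --- Illustrative sketch (not part of the wrapper) --------------------------
# The prelim_map logic described above classifies each sequence ID either by
# an embedded taxonomy ID (a bare number or a "kraken:taxid|NNN" token) or by
# something that looks like an accession number.  A toy version of that
# decision, reusing the same patterns as get_taxid_from_seqid() on made-up
# headers:
import re


def classify_seqid(seqid):
    match = re.search(r"(?:^|\|)kraken:taxid\|(\d+)", seqid)
    if match:
        return ("TAXID", match.group(1))
    if re.match(r"^\d+$", seqid):
        return ("TAXID", seqid)
    match = re.search(r"(?:^|\|)([A-Z]+_?[A-Z0-9]+)(?:\||\b|\.)", seqid)
    if match:
        return ("ACCNUM", match.group(1))
    return (None, None)


assert classify_seqid("kraken:taxid|9606|chr1") == ("TAXID", "9606")
assert classify_seqid("9606") == ("TAXID", "9606")
assert classify_seqid("NC_000913.3") == ("ACCNUM", "NC_000913")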
# def scan_fasta_file( in_file, out_file, lenient=False, sequence_to_url=None ): LOG.info("Generating prelim_map.txt for {}.\n".format(in_file.name)) iterator = in_file iterator_is_dict = False if type(sequence_to_url) is dict: iterator = sequence_to_url iterator_is_dict = True for line in iterator: if not line.startswith(">"): continue remote_filepath = sequence_to_url if iterator_is_dict: remote_filepath = sequence_to_url[line] for match in re.finditer(r"(?:^>|\x01)(\S+)(?: (.*))?", line): seqid = match.group(1) taxid = get_taxid_from_seqid(seqid) comment = match.group(2) or "" if not taxid: if lenient: continue else: sys.exit(1) if re.match(r"^\d+$", taxid): out_file.write( "TAXID\t{:s}\t{:s}\t{:s}\t{:s}\n".format( seqid, taxid, comment, remote_filepath ) ) else: out_file.write( "ACCNUM\t{:s}\t{:s}\t{:s}\t{:s}\n".format( seqid, taxid, comment, remote_filepath ) ) LOG.info( "Finished generating prelim_map.txt for {}.\n".format(in_file.name) ) # This function is part of the Kraken 2 taxonomic sequence # classification system. # # Looks up accession numbers and reports associated taxonomy IDs # # `lookup_list_file` is 1 2-column TSV file w/ sequence IDs and # accession numbers, and `accession_map_files` is a list of # accession2taxid files from NCBI. Output is tab-delimited lines, # with sequence IDs in first column and taxonomy IDs in second. # def lookup_accession_numbers( lookup_list_filename, out_filename, *accession_map_files ): target_lists = {} with open(lookup_list_filename, "r") as f: for line in f: line = line.strip() seqid, acc_num = line.split("\t") if acc_num in target_lists: target_lists[acc_num].append(seqid) else: target_lists[acc_num] = [seqid] initial_target_count = len(target_lists) with open(out_filename, "a") as out_file: for filename in accession_map_files: with open(filename, "r") as in_file: in_file.readline() # discard header line line_count = 0 for line in in_file: line_count += 1 line = line.strip() split = line.split("\t") if len(split) != 4: LOG.warning( "{}:{}-'{}' contains fewer than 4 fields\n" .format(filename, line_count, line) ) continue accession, with_version, taxid, gi = split if accession in target_lists: lst = target_lists[accession] del target_lists[accession] for seqid in lst: out_file.write(seqid + "\t" + taxid + "\n") if len(target_lists) == 0: break if len(target_lists) == 0: break if target_lists: LOG.warning( "{}/{} accession numbers remain unmapped, " "see unmapped_accessions.txt in {} directory\n" .format(len(target_lists), initial_target_count, os.path.abspath(os.curdir)) ) with open("unmapped_accessions.txt", "w") as f: for k in target_lists: f.write(k + "\n") def spawn_masking_subprocess(output_file, threads, protein=False): masking_binary = "segmasker" if protein else "k2mask" if "MASKER" in os.environ: masking_binary = os.environ["MASKER"] masking_binary = find_kraken2_binary(masking_binary) argv = masking_binary + " -outfmt fasta | sed -e '/^>/!s/[a-z]/x/g'" if masking_binary.find("k2mask") >= 0: # k2mask can run multithreaded argv = masking_binary + " -outfmt fasta -threads {} -r x".format( threads ) cwd = os.path.dirname(os.path.abspath(output_file.name)) p = subprocess.Popen( argv, shell=True, cwd=cwd, stdin=subprocess.PIPE, stdout=output_file ) return p # Mask low complexity sequences in the database def mask_files(input_filenames, output_filename, threads, protein=False): with open(output_filename, "wb") as fout: masker = spawn_masking_subprocess(fout, threads, protein) # number_of_files = len(input_filenames) for i, 
input_filename in enumerate(input_filenames): library_name = os.path.basename(os.getcwd()) if "blast" in output_filename: library_name = "blast" if library_name == "added": LOG.info( "Masking low-complexity regions of added " "library {}\n".format(input_filename) ) elif library_name == "blast": LOG.info( "Masking low-complexity regions for blast " "volume {}\n".format(output_filename) ) else: LOG.info( "Masking low-complexity regions of downloaded " "library {:s}\n".format(library_name) ) with open(input_filename, "rb") as fin: shutil.copyfileobj(fin, masker.stdin) # masker(fin, i + 1 == number_of_files) masker.stdin.close() if masker.wait() != 0: LOG.error("Error while masking {}\n".format(input_filename)) def add_file(args, filename, hashes): already_added = False filehash = None if filename in hashes: already_added = True filehash = hashes.get(filename) or hash_file(filename) destination = os.path.basename(filename) ext = ".faa" if args.protein else ".fna" base, _ = os.path.splitext(destination) destination = base + "_" + filehash + ext if already_added: LOG.info( "Already added " + filename + " to library. " "Please remove the entry from added.md5 if this" " is not the case.\n" ) return (filename, filehash, destination) LOG.info("Adding " + filename + " to library " + args.db + "\n") prelim_map_filename = "prelim_map_" + filehash + ".txt" with open(prelim_map_filename, mode="a") as out_file: with open(filename, "r") as in_file: scan_fasta_file( in_file, out_file, lenient=True, sequence_to_url=filename ) shutil.copyfile(filename, destination) if not args.no_masking: mask_files( [destination], destination + ".masked", threads=args.masker_threads, protein=args.protein, ) shutil.move(destination + ".masked", destination) LOG.info("Added " + filename + " to library " + args.db + "\n") return (filename, filehash, destination) def add_to_library(args): if not os.path.isdir(args.db): LOG.error("Invalid database: {:s}\n".format(args.db)) sys.exit(1) library_pathname = os.path.join(args.db, "library") added_pathname = os.path.join(library_pathname, "added") os.makedirs(added_pathname, exist_ok=True) args.files = [os.path.abspath(f) for f in args.files] os.chdir(added_pathname) hashes = {} if os.path.exists("added.md5"): with open("added.md5", "r") as in_file: hashes = dict([line.split()[:2] for line in in_file.readlines()]) with concurrent.futures.ProcessPoolExecutor( max_workers=args.threads ) as pool: futures = [] files = map(lambda f: glob.glob(f, recursive=True), args.files) for filename in itertools.chain(*files): f = functools.partial( wrap_with_globals, add_file, LOG.get_queue(), LOG.get_level(), SCRIPT_PATHNAME ) future = pool.submit(f, args, filename, hashes) if future_raised_exception(future): LOG.error( "Error while adding file to library\n" ) raise future.exception() futures.append(future) with open("added.md5", "a") as out_file: for future in concurrent.futures.as_completed(futures): result = future.result() (filename, filehash, destination) = result out_file.write( filename + "\t" + filehash + "\t" + destination + "\n" ) def make_manifest_from_assembly_summary( args, assembly_summary_file ): asm_level_regex = "|".join(args.assembly_levels).replace("_", " ") suffix = "_protein.faa.gz" if args.protein else "_genomic.fna.gz" manifest_to_taxid = {} for line in assembly_summary_file: if line.startswith("#"): continue fields = line.strip().split("\t") taxid, asm_level, ftp_path = fields[5], fields[11], fields[19] if not re.match(asm_level_regex, asm_level, re.IGNORECASE): continue if 
ftp_path == "na": continue remote_path = ftp_path + "/" + os.path.basename(ftp_path) + suffix url_components = urllib.parse.urlsplit(remote_path) local_path = url_components.path.replace("/", "", 1) manifest_to_taxid[local_path] = taxid with open("manifest.txt", "w") as f: for k in manifest_to_taxid: f.write(k + "\n") return manifest_to_taxid def assign_taxids(args, filepath, manifest_to_taxid, accession_to_taxid={}, filepath_to_url={}): absolute_filepath = os.path.abspath(filepath) sequences_added = 0 ch_added = 0 # taxid = manifest_to_taxid[filepath] out_filepath = "" if absolute_filepath.endswith(".gz"): out_filepath = os.path.splitext(absolute_filepath)[0] else: out_filepath = absolute_filepath + ".tmp" masker = None sequence_to_url = {} os.makedirs(os.path.dirname(absolute_filepath), exist_ok=True) with open(out_filepath, "w") as out_file: if not args.no_masking: masker = spawn_masking_subprocess( out_file, args.masker_threads, False ) opener = open if absolute_filepath.endswith(".gz"): opener = gzip.open with opener(absolute_filepath, "rt") as in_file: while True: line = in_file.readline() if line == "": break if line.startswith(">"): taxid = manifest_to_taxid[filepath] if not taxid: match = re.search(r"GC[AF]_[0-9]{9}\.\d+", line) if not match or match.group(0) not in accession_to_taxid: LOG.error( "Unable to assign taxid to sequence {} in file {}\n" .format(line, in_file.name) ) sys.exit(1) taxid = accession_to_taxid[match.group(0)] line = line.replace(">", ">kraken:taxid|" + taxid + "|", 1) sequence_to_url[line] = filepath_to_url[filepath] sequences_added += 1 else: ch_added += len(line) - 1 if not masker: out_file.write(line) else: masker.stdin.write(line.encode()) taxid = "" if out_filepath.endswith(".tmp"): shutil.move(out_filepath, absolute_filepath) if masker: masker.stdin.close() masker.wait() return (sequences_added, ch_added, sequence_to_url) def download_dataset_by_project(args, endpoint, identifiers): library_pathname = os.path.join(args.db, "library") os.makedirs(library_pathname, exist_ok=True) os.chdir(library_pathname) oldwd = os.path.curdir for identifier in identifiers: dirname = identifier.lower().replace(" ", "_") os.makedirs(dirname, exist_ok=True) os.chdir(dirname) download_and_process_accessions(args, endpoint, identifier) os.chdir(oldwd) def download_and_process_accessions(args, endpoint, identifier): api = http.client.HTTPSConnection(NCBI_REST_API) identifier = urllib.parse.quote(identifier) accession_to_taxid = {} builder = NCBI_URI_Builder(endpoint, identifier, "dataset_report") builder.set_filters_from_args(args) builder = builder.page_size(500) old_page_token = "" while True: api.request("GET", builder.build()) response = api.getresponse() if response.status == 429: LOG.warning( "Connection is being rate limited by NCBI, backing off\n" ) time.sleep(1) response.close() continue if response.status != 200: LOG.error( "Encountered an error while trying to gather " "accessions." 
) sys.exit(1) results = response.readlines()[0] results = json.loads(results) response.close() if not results: LOG.error( "Could not find any accessions matching the query: {}\n" .format(identifier) ) sys.exit(1) for report in results["reports"]: accession_to_taxid[report["accession"]] =\ report["organism"]["tax_id"] if "next_page_token" in results \ and results["next_page_token"] != old_page_token: builder = builder.page_token(results["next_page_token"]) old_page_token = results["next_page_token"] else: break api.close() LOG.info( "Found {} accession(s) associated with {}\n" .format(len(accession_to_taxid), identifier) ) accessions = list(accession_to_taxid.keys()) accession_to_url =\ map_accessions_to_url_parallel(args, accessions) filepath_to_taxid = {} with open("manifest.txt", "w") as fout: for accession, url in accession_to_url.items(): # exclude the leading / filepath = urllib.parse.urlparse(url).path[1:] fout.write(filepath + "\n") taxid = accession_to_taxid[accession] filepath_to_taxid[filepath] = str(taxid) download_files_from_manifest(NCBI_SERVER, args.threads, resume=args.resume) filepath_to_url = {} for filepath in filepath_to_taxid.keys(): accession = re.search(r"GC[AF]_\d{9}\.\d+", filepath).group() filepath_to_url[filepath] = accession_to_url[accession] sequence_to_url = assign_taxid_to_sequences( args, filepath_to_taxid, filepath_to_url=filepath_to_url ) library_filename = "library.faa" if args.protein else "library.fna" with open(library_filename, "r") as in_file: with open("prelim_map.txt", "w") as out_file: out_file.write("# prelim_map for " + args.library + "\n") scan_fasta_file( in_file, out_file, sequence_to_url=sequence_to_url, ) def download_accessions(args, accession_to_taxid): accessions = sorted(list(accession_to_taxid.keys())) filepath_to_taxid = {} zip_file_list = [] zip_filename_to_accessions = {} extension = ".faa" if args.protein else ".fna" if os.path.exists("zips"): downloaded_accessions = {} for zip_filename in os.listdir("zips"): zip_filename = os.path.join("zips", zip_filename) if zip_filename.endswith(".zip"): with zipfile.ZipFile(zip_filename) as zip: for filename in zip.namelist(): match = re.search(r"GC[AF]_\d{9}\.\d+", filename) if match and filename.endswith(extension): accession = match.group() downloaded_accessions[accession] = zip_filename zip_file_list.append(zip_filename) if downloaded_accessions: unfetched_accessions = [] for accession in accessions: if accession in downloaded_accessions: LOG.info( "Already downloaded accession: {}, skipping\n" .format(accession) ) zip_filename = downloaded_accessions[accession] if zip_filename in zip_filename_to_accessions: zip_filename_to_accessions[zip_filename].append(accession) else: zip_filename_to_accessions[zip_filename] = [accession] else: unfetched_accessions.append(accession) accessions = unfetched_accessions partitions = [] if accessions: number_of_partitions = math.ceil(len(accessions) / 400) number_of_partitions = max(args.threads, number_of_partitions) partitions = partition_list(accessions, number_of_partitions) with concurrent.futures.ThreadPoolExecutor(max_workers=args.threads) as pool: download_futures = [] for partition in partitions: download_futures.append( pool.submit(download_zip_from_ncbi, args, partition) ) unzip_futures = [] for future in concurrent.futures.as_completed(download_futures): unzip_futures.append( pool.submit( extract_fastas_from_zip_file, future.result(), args.protein ) ) for zip_filename, accessions in zip_filename_to_accessions.items(): unzip_futures.append( 
pool.submit( extract_fastas_from_zip_file, zip_filename, args.protein, accessions ) ) for future in concurrent.futures.as_completed(unzip_futures): accession_to_filepath = future.result() for accession, filepath in accession_to_filepath.items(): filepath_to_taxid[filepath] = str(accession_to_taxid[accession]) shutil.rmtree("ncbi_dataset", ignore_errors=True) return filepath_to_taxid def extract_fastas_from_zip_file(filename, protein, accessions_to_extract=None): accession_to_filepath = {} suffix = "_protein.faa" if protein else "_genomic.fna" with zipfile.ZipFile(filename) as zip: for entry in zip.namelist(): if entry.endswith(suffix): dir_components = ["genomes", "all"] filename = os.path.basename(entry) dirname = filename.replace(suffix, "") accession = re.search(r"GC[AF]_\d{9}\.\d+", filename).group() if accessions_to_extract and\ accession not in accessions_to_extract: continue modified_accession = accession.split(".")[0].replace("_", "") dir_components.extend(partition_list(modified_accession, 4)) dir_components.append(dirname) dir_components.append(filename) filepath = os.path.join("", *dir_components) entry_size = zip.getinfo(entry).file_size if os.path.exists(filepath) and os.stat(filepath).st_size == entry_size: LOG.info( "Already extracted {} from {}\n" .format(entry, os.path.abspath(filename)) ) else: LOG.info( "Extracting {} from {}\n".format(entry, os.path.abspath(filename)) ) zip.extract(entry, os.path.curdir) os.makedirs(os.path.dirname(filepath), exist_ok=True) LOG.debug( "Moving {} to {}\n".format( os.path.abspath(entry), os.path.abspath(filepath) ) ) shutil.move(entry, filepath) accession_to_filepath[accession] = filepath return accession_to_filepath def download_zip_from_ncbi(args, accessions): api = http.client.HTTPSConnection(NCBI_REST_API) accessions = ",".join(accessions) md5 = hash_string(accessions) os.makedirs("zips", exist_ok=True) filename = os.path.join("zips", md5 + ".zip") tmp_filename = filename + ".tmp" if os.path.exists(os.path.join("zip", filename)): LOG.info( "Already downloaded {} from NCBI which contains accessions: {}\n" .format(os.path.basename(filename), accessions) ) return filename LOG.info( "Downloading {} from NCBI containing the following accessions: {}\n" .format(os.path.basename(filename), accessions) ) accessions = urllib.parse.quote(accessions) annotation_type = "PROT_FASTA" if args.protein else "GENOME_FASTA" builder = NCBI_URI_Builder("accession", accessions, "download") builder = builder.include_annotation_type(annotation_type) api.request("GET", builder.build()) res = api.getresponse() with open(tmp_filename, "wb") as fout: shutil.copyfileobj(res, fout) res.close() shutil.move(tmp_filename, filename) LOG.info("Saved {} to {}\n".format( os.path.basename(filename), os.path.abspath(filename) )) return filename def partition_list(list, num_partitions): partitions = [] length = len(list) if length == 0: return [] step = math.ceil(length / num_partitions) if step == 0: return [] for i in range(0, length, step): end = i + step if end > len(list) or num_partitions == 1: end = len(list) partitions.append(list[i:end]) if num_partitions == 1: break num_partitions -= 1 return partitions def map_accessions_to_url(accessions, protein=False): api = http.client.HTTPSConnection(NCBI_REST_API) size = len(accessions) start = 0 step = 100 stop = min(size, step) base_uri = "/datasets/v2/genome/accession/{}/links" accession_to_url = {} while True: query = ",".join(accessions[start:stop]) query = urllib.parse.quote(query) headers = {'accept': 
'application/json'} api.request("GET", base_uri.format(query), headers=headers) res = api.getresponse() if res.status != 200: res.close() time.sleep(1) api.connect() continue results = res.readlines()[0] res.close() results = json.loads(results) for entry in results["assembly_links"]: if entry["assembly_link_type"] == "FTP_LINK": filepath = get_download_path( entry["resource_link"], protein ) accession_to_url[entry["accession"]] =\ url_join(NCBI_SERVER, path=filepath) start, stop = stop, min(size, stop + step) if start == stop: break api.close() return accession_to_url def map_accessions_to_url_parallel(args, accessions): futures = [] accession_to_url = {} LOG.info("Fetching download links for accessions\n") # We are limited to 5 requests per second, so we limit the number of # workers accordingly. workers = min(args.threads, 5) with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as pool: for partition in partition_list(accessions, args.threads): future = pool.submit( map_accessions_to_url, partition, args.protein ) futures.append(future) for future in concurrent.futures.as_completed(futures): accession_to_url.update(future.result()) # accession_to_url = map_accessions_to_url(accessions, args.protein) LOG.info("Finished fetching download links for accessions\n") return accession_to_url def get_download_path(resource_link, protein): dir = os.path.basename(resource_link) filename = dir + "_genomic.fna.gz" resource_link += "/" if protein: filename = dir + "_protein.faa.gz" resource_link = urllib.parse.urljoin(resource_link, filename) path = urllib.parse.urlparse(resource_link).path # conn.request("HEAD", path) # response = conn.getresponse() # if response.status != 200: # LOG.error("Unable to find ...") # sys.exit(1) # response.close() return path def assign_taxid_to_sequences(args, manifest_to_taxid, accession_to_taxid={}, filepath_to_url={}): if args.no_masking: LOG.info("Assigning taxonomic IDs to sequences\n") else: LOG.info( "Assigning taxonomic IDs and masking sequences\n" ) library_filename = "library.faa" if args.protein else "library.fna" sequence_to_url = {} projects_added = 0 total_projects = len(manifest_to_taxid) sequences_added = 0 ch_added = 0 ch = "aa" if args.protein else "bp" out_line = progress_line( projects_added, total_projects, sequences_added, ch_added, ch ) LOG.debug("{:s}\r".format(out_line)) filepaths = sorted(manifest_to_taxid) with concurrent.futures.ProcessPoolExecutor( max_workers=args.threads ) as pool: futures = [] max_out_line_len = 0 for filepath in filepaths: f = functools.partial( wrap_with_globals, assign_taxids, LOG.get_queue(), LOG.get_level(), SCRIPT_PATHNAME ) future = pool.submit( f, args, filepath, manifest_to_taxid, accession_to_taxid, filepath_to_url ) if future_raised_exception(future): LOG.error( "Error encountered while assigning tax IDs\n" ) raise future.exception() futures.append(future) for future in concurrent.futures.as_completed(futures): result = future.result() sequences_added += result[0] ch_added += result[1] projects_added += 1 sequence_to_url.update(result[2]) out_line = progress_line( projects_added, total_projects, sequences_added, ch_added, ch ) max_out_line_len = max(len(out_line), max_out_line_len) padding = " " * (max_out_line_len - len(out_line)) LOG.debug("{:s}\r".format(out_line + padding)) if args.no_masking: LOG.info("Finished assigning taxonomic IDs to sequences\n") else: LOG.info("Finished assigning taxonomic IDs and masking sequences\n") LOG.info("Generating {:s}\n".format(library_filename)) with 
open(library_filename, "w") as out_file: for filepath in filepaths: if filepath.endswith(".gz"): filepath = os.path.splitext(filepath)[0] with open(filepath, "r") as in_file: shutil.copyfileobj(in_file, out_file) LOG.info("Finished generating {:s}\n".format(library_filename)) return sequence_to_url def progress_line(projects, total_projects, seqs, chars, ch): line = "Processed " if projects == total_projects: line += str(projects) else: line += "{:d}/{:d}".format(projects, total_projects) line += " project(s), {:d} sequence(s), ".format(seqs) prefix = None for p in ["k", "M", "G", "T", "P", "E"]: if chars >= 1024: prefix = p chars /= 1024 else: break if prefix: line += "{:.2f} {:s}{:s}".format(chars, prefix, ch) else: line += "{:.2f} {:s}".format(chars, ch) return line # The following three functions have dummy return values. This is # so that we can check whether a future stopped running as a # result of an exception, so that we can stop the main process # early. # def decompress_files(compressed_filenames, out_filename=None, buf_size=8192): # if isinstance(compressed_filenames, str): # compressed_filenames = [compressed_filenames] # if out_filename: # if os.path.exists(out_filename + ".tmp"): # os.remove(out_filename + ".tmp") # with open(out_filename + ".tmp", "ab") as out_file: # for filename in compressed_filenames: # with gzip.open(filename) as gz: # decompress_file(gz, out_file) # os.rename(out_filename + ".tmp", out_filename) # else: # for filename in compressed_filenames: # out_filename, ext = os.path.splitext(filename) # if os.path.exists(out_filename + ".tmp"): # os.remove(out_filename + ".tmp") # with gzip.open(filename) as gz: # with open(out_filename + ".tmp", "wb") as out: # decompress_file(gz, out, buf_size) # os.rename(out_filename + ".tmp", out_filename) # return True def decompress_files(compressed_filenames, out_filename=None, buf_size=8192): if isinstance(compressed_filenames, str): compressed_filenames = [compressed_filenames] if out_filename: if os.path.exists(out_filename + ".tmp"): os.remove(out_filename + ".tmp") with open(out_filename + ".tmp", "ab") as out_file: for filename in compressed_filenames: with open(filename, "rb") as gz: decompress_file(gz, out_file) os.rename(out_filename + ".tmp", out_filename) else: for filename in compressed_filenames: out_filename, ext = os.path.splitext(filename) if os.path.exists(out_filename + ".tmp"): os.remove(out_filename + ".tmp") with open(filename, "rb") as gz: with open(out_filename + ".tmp", "wb") as out: decompress_file(gz, out, buf_size) os.rename(out_filename + ".tmp", out_filename) return True def decompress_file(in_file, out_file, buf_size=8129): LOG.info( "Decompressing {:s}\n".format(os.path.join(os.getcwd(), in_file.name)) ) inflator = zlib.decompressobj(15 + 32) while True: data = in_file.read(buf_size) if not data: break inflated_data = inflator.decompress(data) out_file.write(inflated_data) # shutil.copyfileobj(in_file, out_file, buf_size) LOG.info( "Finished decompressing {:s}\n".format( os.path.join(os.getcwd(), in_file.name) ) ) return True def decompress_and_mask(filepath, masker_threads): out_filepath = os.path.splitext(filepath)[0] with open(out_filepath, "w") as out_file: masker = spawn_masking_subprocess(out_file, masker_threads) with open(filepath, "rb") as in_file: decompress_file(in_file, masker.stdin) masker.stdin.close() masker.wait() return True def download_log(filename, total_size=None): pb = None current_size = 0 def inner(block_number, read_size, size): nonlocal pb, current_size, total_size 
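# --- Illustrative sketch (not part of the wrapper) --------------------------
# download_log() builds a urllib reporthook: urlretrieve invokes it as
# hook(block_number, block_size, total_size) after each block, which is what
# drives the ProgressBar output above.  A minimal standalone reporthook,
# printing percent complete to stderr (URL and filename are placeholders):
import sys
import urllib.request


def simple_reporthook(block_number, block_size, total_size):
    if total_size <= 0:
        return
    done = min(block_number * block_size, total_size)
    sys.stderr.write("\r{:3.0f}%".format(100 * done / total_size))


# Example (not executed here):
# urllib.request.urlretrieve("https://example.org/file.gz", "file.gz",
#                            reporthook=simple_reporthook)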
if not pb: pb = ProgressBar(total_size or size) current_size += read_size pb.progress(current_size) LOG.debug( "{:s} {: >10s}\r".format(pb.get_bar(), format_bytes(current_size)) ) return inner def http_download_file(url, local_name=None, call_back=None): if not local_name: local_name = urllib.parse.urlparse(url).path.split("/")[-1] else: local_name = os.path.abspath(local_name) os.makedirs(os.path.dirname(local_name), exist_ok=True) with urllib.request.urlopen(url) as conn: remote_size = int(conn.headers["Content-Length"]) local_size = ( os.stat(local_name).st_size if os.path.exists(local_name) else 0 ) if local_size == remote_size: LOG.info( "Already downloaded {:s}\n".format(get_abs_path(local_name)) ) return LOG.info("Beginning download of {:s}\n".format(url)) urllib.request.urlretrieve( url, local_name, reporthook=(call_back or download_log(local_name)) ) clear_console_line() LOG.info("Saved {:s} to {:s}\n".format(local_name, os.getcwd())) def http_download_file2(server, urls, save_to=None, md5sums=None): conn = None if isinstance(server, str): conn = http.client.HTTPSConnection(server, timeout=60) else: conn = server server = conn.host md5 = md5sums if md5sums else {} i = 0 num_urls = len(urls) skip_md5_check = False while i < num_urls: url = urls[i].strip() filename = os.path.basename(url) local_name = os.path.abspath(url) if save_to: local_name = os.path.join(save_to, filename) tmp_local_name = local_name + ".tmp" local_directory = os.path.dirname(local_name) os.makedirs(local_directory, exist_ok=True) try: if filename not in md5: checksums = [] remote_dirname = os.path.dirname(url) for md5_filename in ["md5checksums.txt", filename + ".md5", "MD5SUM.txt"]: # Check if file exists before trying to download. # This avoids NCBI sending BadStatusLine when # making the request. conn.request( "HEAD", "/" + remote_dirname + "/" + md5_filename ) response = conn.getresponse() if response.status == 200: response.close() # If we have found the file, then go ahead # and download. conn.request( "GET", "/" + remote_dirname + "/" + md5_filename ) response = conn.getresponse() checksums = response.readlines() response.close() break else: response.close() if len(checksums) > 0: for checksum in checksums: (md5sum, remote_filename) = checksum.split() remote_filename = os.path.basename( remote_filename.decode() ) md5[remote_filename] = md5sum.decode() if not skip_md5_check and os.path.exists(local_name)\ and filename in md5 and md5[filename] == hash_file(local_name): LOG.info( "Already downloaded {:s}\n".format( urllib.parse.urljoin(server, url) ) ) # Server can potentially end the connection while we # waiting for the md5 hash to be computed. We have # to reconnect to avoid failures when trying to # retrieve the file. conn.connect() i += 1 continue LOG.info("Beginning download of {:s}\n".format(server + "/" + url)) with open(tmp_local_name, "wb") as out_file: conn.request("GET", "/" + url) response = conn.getresponse() if response.status == 200: shutil.copyfileobj(response, out_file, 8192) elif response.status == 404: LOG.warning( "Cannot find file: {}.\n" "Please report this issue to NCBI.\n".format(url) ) else: LOG.error( "Error downloading file: {}.\n" "Reason: {}\n".format(url, response.reason) ) response.read() sys.exit(1) response.close() shutil.move(tmp_local_name, local_name) except (http.client.HTTPException, http.client.RemoteDisconnected) as e: LOG.warning( "Unable to download " + url + ". 
Reason: {}, will try again\n" .format(e) ) conn.close() time.sleep(0.1) conn.connect() continue # Check the MD5 sum of the downloaded file to make sure that it # was downloaded successfully. This prevents issues when # assigning tax IDs to files. local_md5sum = hash_file(local_name) if filename in md5: if md5[filename] != local_md5sum: LOG.warning( "The MD5 sum of {} does not match the MD5 provided" " by the server. The file will be downloaded again.\n" .format(local_name) ) # We have already confirmed that MD5 sum does not match # do not bother checking it again at the top of the loop. skip_md5_check = True conn.connect() continue else: LOG.info( "The remote and local MD5 sum of {} match\n" .format(local_name) ) LOG.info( "Saved {:s} to {:s}\n".format(filename, local_directory) ) # Reset the MD5 check for the next file. skip_md5_check = False i += 1 def make_file_filter(file_handle, regex): def inner(listing): path = listing.split()[-1] if path.endswith(regex): file_handle.write(path + "\n") return inner def move(src, dst): src = os.path.abspath(src) dst = os.path.abspath(dst) if os.path.isfile(src) and os.path.isdir(dst): dst = os.path.join(dst, os.path.basename(src)) shutil.move(src, dst) def get_manifest_and_md5sums(server, remote_directory, regex): ftp = ftplib.FTP(server) ftp.login() sstream = io.StringIO() ftp.cwd(remote_directory) ftp.retrlines("LIST", callback=make_file_filter(sstream, regex)) with open("manifest.txt", "w") as out: for line in sstream.getvalue().split(): out.write(urllib.parse.urljoin(remote_directory, line) + "\n") sstream.truncate(0) sstream.seek(0) ftp.cwd("/refseq/release/release-catalog") ftp.retrlines("LIST", callback=make_file_filter(sstream, "installed")) install_file = sstream.getvalue().strip() bstream = io.BytesIO() ftp.retrbinary("RETR " + install_file, callback=bstream.write) ftp.close() md5sums = {} for line in bstream.getvalue().split(b"\n"): if line.find(b"plasmid") == -1: continue (md5, filename) = line.split() md5sums[filename.decode()] = md5.decode() return md5sums def download_files_from_manifest( server, threads=1, manifest_filename="manifest.txt", filepath_to_taxid_table=None, md5sums=None, resume=False ): threads = min(threads, 12) with concurrent.futures.ThreadPoolExecutor(max_workers=threads) as pool: with open(manifest_filename, "r") as f: filepaths = f.readlines() if resume: nonexistent_filepaths = [] for filepath in filepaths: abs_filepath = os.path.abspath(filepath.strip()) if not os.path.exists(abs_filepath): nonexistent_filepaths.append(filepath) filepaths = nonexistent_filepaths # We try to reduce the risk of a single thread downloading # many large files. random.shuffle(filepaths) partitions = [] futures = [] partitions = partition_list(filepaths, threads) for partition in partitions: future = pool.submit( http_download_file2, server, partition, None, md5sums ) futures.append(future) (done, not_done) = concurrent.futures.wait(futures) if len(not_done) != 0: LOG.error("Error encountered while trying to download files\n") for future in not_done: LOG.error(future.exception()) sys.exit(1) # Make sure that all files are downloaded. This has been an issue # with large collections like bacteria. 
for filepath in filepaths: if not os.path.exists(filepath.strip()): http_download_file2(server, [filepath], None, md5sums) def download_and_decompress(filename): http_download_file2( NCBI_SERVER, [filename], save_to=os.path.abspath(os.curdir) ) decompress_files(os.path.abspath(os.path.basename(filename))) def download_taxonomy(args): taxonomy_path = os.path.join(args.db, "taxonomy") os.makedirs(taxonomy_path, exist_ok=True) os.chdir(taxonomy_path) futures = [] with concurrent.futures.ProcessPoolExecutor( max_workers=2 ) as pool: if not args.skip_maps: if not args.protein: for subsection in ["gb", "wgs"]: filename = "pub/taxonomy/accession2taxid/" filename += "nucl_" + subsection + ".accession2taxid.gz" f = functools.partial( wrap_with_globals, download_and_decompress, LOG.get_queue(), LOG.get_level(), SCRIPT_PATHNAME ) future = pool.submit( f, filename ) if future_raised_exception(future): LOG.error( "Error encountered while downloading file\n" ) raise future.exception() futures.append(future) else: filename = "/pub/taxonomy/accession2taxid/" filename += "prot.accession2taxid.gz" f = functools.partial( wrap_with_globals, download_and_decompress, LOG.get_queue(), LOG.get_level(), SCRIPT_PATHNAME ) future = pool.submit(f, filename) if future_raised_exception(future): LOG.error( "Error encountered while downloading file\n" ) raise future.exception() futures.append(future) LOG.info("Downloading taxonomy tree data\n") filename = "pub/taxonomy/taxdump.tar.gz" http_download_file2( NCBI_SERVER, [filename], save_to=os.path.abspath(os.curdir) ) LOG.info("Untarring taxonomy tree data\n") with tarfile.open("taxdump.tar.gz", "r:gz") as tar: tar.extractall() LOG.info("Finished Untarring taxonomy tree data\n") concurrent.futures.wait(futures) def download_gtdb_taxonomy(args, files, md5s): taxonomy_path = os.path.join(args.db, "taxonomy") os.makedirs(taxonomy_path, exist_ok=True) os.chdir(taxonomy_path) LOG.info("Dowloading GTDB taxonomy for bacteria and archaea\n") http_download_file2( GTDB_SERVER, files, save_to=os.path.abspath(os.curdir), md5sums=md5s ) LOG.info("Finished downloading GTDB taxonomy for bacteria and archaea\n") def build_gtdb_taxonomy(in_file): rank_codes = { "d": "domain", "p": "phylum", "c": "class", "o": "order", "f": "family", "g": "genus", "s": "species", } accession_map = {} seen_it = collections.defaultdict(int) child_data = collections.defaultdict(lambda: collections.defaultdict(int)) for line in in_file: line = line.strip() accession, taxonomy_string = line.split("\t") start = accession.find("GCA") if start < 0: start = accession.find("GCF") accession = accession[start:] taxonomy_string = re.sub("(;[a-z]__)+$", "", taxonomy_string) accession_map[accession] = taxonomy_string seen_it[taxonomy_string] += 1 if seen_it[taxonomy_string] > 1: continue while True: match = re.search("(;[a-z]__[^;]+$)", taxonomy_string) if not match: break level = match.group(1) taxonomy_string = re.sub("(;[a-z]__[^;]+$)", "", taxonomy_string) key = taxonomy_string + level child_data[taxonomy_string][key] += 1 seen_it[taxonomy_string] += 1 if seen_it[taxonomy_string] > 1: break if seen_it[taxonomy_string] == 1: child_data["root"][taxonomy_string] += 1 id_map = {} next_node_id = 1 LOG.info("Generating nodes.dmp and names.dmp\n") with open("names.dmp", "w") as names_file: with open("nodes.dmp", "w") as nodes_file: bfs_queue = [["root", 1]] while len(bfs_queue) > 0: node, parent_id = bfs_queue.pop() display_name = node rank = None match = re.search("([a-z])__([^;]+)$", node) if match: rank = 
rank_codes[match.group(1)] display_name = match.group(2) rank = rank or "no rank" node_id, next_node_id = next_node_id, next_node_id + 1 id_map[node] = node_id names_file.write( "{:d}\t|\t{:s}\t|\t-\t|\tscientific name\t|\n".format( node_id, display_name ) ) nodes_file.write( "{:d}\t|\t{:d}\t|\t{:s}\t|\t-\t|\n".format( node_id, parent_id, rank ) ) children = ( sorted([key for key in child_data[node]]) if node in child_data else [] ) for node in children: bfs_queue.insert(0, [node, node_id]) with open("gtdb.accession2taxid", "w") as f: for accession in sorted([key for key in accession_map]): taxid = id_map[accession_map[accession]] accession_without_revision = accession.split(".")[0] f.write("{:s}\t{:s}\t{:d}\t-\n".format( accession_without_revision, accession, taxid )) def download_gtdb_genomes(args, remote_filepath, md5s): for directory in ["taxonomy", "library"]: os.makedirs(directory, exist_ok=True) filename = os.path.basename(remote_filepath) filepath = os.path.abspath(filename) http_download_file2( GTDB_SERVER, [remote_filepath], save_to=os.curdir, md5sums=md5s ) os.chdir("library") accession_to_filepath = {} filepaths_without_accession = [] ext = ".faa" if args.protein else ".fna" library = "" if filename.endswith(".tar.gz"): with tarfile.open(filepath, "r:gz") as tar: library = tar.getnames()[0] for member in tar.getmembers(): if member.isfile() and re.search(ext, member.name): if re.search("GC[AF]", member.name): filename = os.path.basename(member.name) accession = re.search( r"(GC[AF]_\d{9}\.\d+)", filename ).group(1) if os.path.exists(member.name)\ and member.size == os.stat(member.name).st_size: LOG.info( "Already extracted {}...skipping\n" .format(member.name) ) else: LOG.info("Extracting {}\n".format(member.name)) tar.extract(member) LOG.info( "Finished extracting {}\n" .format(member.name) ) accession_to_filepath[accession] =\ os.path.abspath(member.name) else: # We do not check if a file has already been extracted here LOG.info("Extracting {}\n".format(member.name)) tar.extract(member) LOG.info( "Finished extracting {}\n".format(member.name) ) filepaths_without_accession.append( os.path.abspath(member.name) ) else: library = filename.split(".")[0] filepaths_without_accession.append(os.path.abspath(filepath)) return (library, accession_to_filepath, filepaths_without_accession) def identity(object): return object def remap_secondary_taxids(taxids): new_taxids = {} partitions = math.ceil(len(taxids) / 500) conn = http.client.HTTPSConnection(NCBI_REST_API) for taxid_list in partition_list(taxids, partitions): endpoint = NCBI_URI_Builder( "taxonomy", "taxon", taxid_list ).build() conn.request("GET", endpoint) response = conn.getresponse() if response.status == 200: result = response.readlines()[0] response.close() for entry in json.loads( result, parse_float=identity, parse_int=identity )["taxonomy_nodes"]: query = entry["query"][0] taxid = entry["taxonomy"]["tax_id"] new_taxids[query] = taxid else: response.close() return new_taxids def read_gtdb_metadata(filenames): metadata = {} for filename in filenames: LOG.info("Reading NCBI tax IDs from {}\n".format(filename)) with gzip.open(filename, "rt") as in_file: reader = csv.DictReader(in_file, delimiter="\t") for row in reader: metadata[row["accession"]] = row["ncbi_taxid"] LOG.info("Finished reading NCBI tax IDs from {}\n".format(filename)) return metadata def find_and_remap_secondary_taxids(metadata): gtdb_assigned_taxids = set() for value in metadata.values(): gtdb_assigned_taxids.add(value) ncbi_taxids = set() with 
open("nodes.dmp", "r") as in_file: for line in in_file: taxid = line.split()[0] ncbi_taxids.add(taxid) taxids = list(gtdb_assigned_taxids.difference(ncbi_taxids)) LOG.info( "The following tax IDs were not found in nodes.dmp" " and need to be remapped\n" ) remapped_taxids = remap_secondary_taxids(taxids) for accession, taxid in metadata.items(): if taxid in remapped_taxids: LOG.info( "Remapping {} to {}\n".format(taxid, remapped_taxids[taxid]) ) metadata[accession] = remapped_taxids[taxid] return metadata def get_gtdb_latest_md5sums(path_prefix): try: md5_url = url_join( GTDB_SERVER, path=path_prefix + "/releases/latest/MD5SUM.txt" ) http_download_file(md5_url) except Exception: url = url_join( GTDB_SERVER, path=path_prefix + "/releases/latest/VERSION.txt" ) # remove the leading 'v' from the version number version = urllib.request.urlopen(url).readline().decode().strip()[1:] md5_url = url_join( GTDB_SERVER, path=path_prefix + "/releases/release" + version + "/MD5SUM.txt" ) http_download_file(md5_url) def get_needed_files(args, path_prefix): files_needed = collections.defaultdict(list) md5s = {} with open("MD5SUM.txt", "r") as in_file: # filename_regex = r"genomes|genes|fna" filename_regex = "|".join(args.gtdb_files) pattern = re.compile(filename_regex) candidate_files = [] if args.protein: filename_regex = r"protein|faa" for line in in_file: md5sum, filepath = line.split() filepath = urllib.parse.urljoin( path_prefix + "/releases/latest/", filepath ) # remove the release tag since they do not appear in the "latest" # file listings filepath = re.sub(r"_r\d+", "", filepath) if filepath.find("genomic_files") != -1: candidate_files.append(os.path.basename(filepath)) if re.search(r"taxonomy.*\.tsv$", filepath): files_needed["taxonomy"].append(filepath) if re.search(r".*metadata.*\.tsv.gz$", filepath): files_needed["metadata"].append(filepath) elif re.search(pattern, filepath): files_needed["fasta"].append(filepath) filepath = os.path.basename(filepath) md5s[filepath] = md5sum if len(files_needed["fasta"]) != len(args.gtdb_files): LOG.error("At least one of the files did not match: {}\n" .format(", ".join(args.gtdb_files))) LOG.error("Here is a list of candidates:\n{}\n" .format("\n".join(candidate_files))) sys.exit(1) return (md5s, files_needed) def build_gtdb_database(args): global GTDB_SERVER db_pathname = os.path.abspath(args.db) os.makedirs(db_pathname, exist_ok=True) os.chdir(db_pathname) path_prefix = "/public/gtdb/data" if args.gtdb_server != GTDB_SERVER: GTDB_SERVER = args.gtdb_server path_prefix = "" get_gtdb_latest_md5sums(path_prefix) md5s, files_needed = get_needed_files(args, path_prefix) download_gtdb_taxonomy(args, files_needed["taxonomy"], md5s) os.chdir(os.path.join(db_pathname, "taxonomy")) LOG.info("Merging Archaea and Bacteria taxonomies\n") with open("merged_taxonomy.tsv", "w") as file_out: for tax_filename in files_needed["taxonomy"]: tax_filename = os.path.basename(tax_filename) with open(tax_filename, "r") as file_in: shutil.copyfileobj(file_in, file_out) LOG.info("Finished merging Archaea and Bacteria taxonomies\n") accession_to_taxid = {} if not args.gtdb_use_ncbi_taxonomy: with open("merged_taxonomy.tsv", "r") as in_file: build_gtdb_taxonomy(in_file) else: for metadata in files_needed["metadata"]: url = url_join(GTDB_SERVER, path=metadata) http_download_file(url) metadata_filenames = map(os.path.basename, files_needed["metadata"]) metadata = read_gtdb_metadata(metadata_filenames) args.skip_maps = True download_taxonomy(args) metadata = 
find_and_remap_secondary_taxids(metadata) for accession, taxid in metadata.items(): accession_to_taxid[accession[3:]] = taxid workers = len(files_needed["fasta"]) futures = [] accession_to_filepath = {} # These files are not tied to a single accession, but # instead each sequence in the FASTA has its own accession. # os.chdir(os.path.join(db_pathname, "taxonomy")) if not args.gtdb_use_ncbi_taxonomy: with open("gtdb.accession2taxid", "r") as in_file: for line in in_file: base_accession, accession, taxid, gi = line.split("\t") accession_to_taxid[accession] = taxid with concurrent.futures.ProcessPoolExecutor( max_workers=workers ) as pool: for remote_filepath in files_needed["fasta"]: os.chdir(db_pathname) f = functools.partial( wrap_with_globals, download_gtdb_genomes, LOG.get_queue(), LOG.get_level(), SCRIPT_PATHNAME ) futures.append(pool.submit( f, args, remote_filepath, md5s )) for future in concurrent.futures.as_completed(futures): result = future.result() library = result[0] filepaths_without_accessions = [] filepath_to_url = {} for accession, filepath in result[1].items(): accession_to_filepath[accession] = filepath filepath_to_url[filepath] =\ url_join(GTDB_SERVER, path=remote_filepath) for filepath in result[2]: filepaths_without_accessions.append(filepath) filepath_to_url[filepath] =\ url_join(GTDB_SERVER, path=remote_filepath) filepath_to_taxid_table = {} library_pathname = os.path.join(db_pathname, os.path.join("library", library)) os.makedirs(library_pathname, exist_ok=True) os.chdir(library_pathname) for accession, filepath in accession_to_filepath.items(): filepath_to_taxid_table[filepath] =\ accession_to_taxid[accession] for filepath in filepaths_without_accessions: filepath_to_taxid_table[filepath] = "" sequence_to_url = assign_taxid_to_sequences( args, filepath_to_taxid_table, accession_to_taxid, filepath_to_url ) with open("library.fna", "r") as in_file: with open("prelim_map.txt", "w") as out_file: out_file.write("# prelim_map for {:s}\n".format(library)) scan_fasta_file( in_file, out_file, sequence_to_url=sequence_to_url, ) os.chdir(db_pathname) build_kraken2_db(args) def download_genomic_library(args): library_filename = "library.faa" if args.protein else "library.fna" library_pathname = os.path.join(args.db, "library") LOG.info("Adding {:s} to {:s}\n".format(args.library, args.db)) if args.library in [ "archaea", "bacteria", "viral", "fungi", "invertebrate", "plant", "human", "protozoa", "vertebrate_mammalian", "vertebrate_other" ]: library_pathname = os.path.join(library_pathname, args.library) os.makedirs(library_pathname, exist_ok=True) os.chdir(library_pathname) try: os.remove("assembly_summary.txt") except FileNotFoundError: pass remote_dir_name = args.library if args.library == "human": remote_dir_name = "vertebrate_mammalian/Homo_sapiens" try: if args.assembly_source == "all": # Download and merge assembly summaries from both RefSeq and GenBank refseq_url = "genomes/refseq/{:s}/assembly_summary.txt".format(remote_dir_name) genbank_url = "genomes/genbank/{:s}/assembly_summary.txt".format(remote_dir_name) # Download RefSeq assembly summary http_download_file2( NCBI_SERVER, [refseq_url], save_to=os.path.abspath(os.curdir) ) os.rename("assembly_summary.txt", "assembly_summary_refseq.txt") # Download GenBank assembly summary http_download_file2( NCBI_SERVER, [genbank_url], save_to=os.path.abspath(os.curdir) ) os.rename("assembly_summary.txt", "assembly_summary_genbank.txt") # Merge the two files, keeping the header from RefSeq only with open("assembly_summary.txt", 
"w") as merged_file: # Write RefSeq entries with open("assembly_summary_refseq.txt", "r") as refseq_file: merged_file.write(refseq_file.read()) # Write GenBank entries (skip header lines starting with #) with open("assembly_summary_genbank.txt", "r") as genbank_file: for line in genbank_file: if not line.startswith("#"): merged_file.write(line) # Clean up temporary files os.remove("assembly_summary_refseq.txt") os.remove("assembly_summary_genbank.txt") else: # Original behavior for "refseq" or "genbank" url = "genomes/{}/{:s}/assembly_summary.txt".format( args.assembly_source, remote_dir_name ) http_download_file2( NCBI_SERVER, [url], save_to=os.path.abspath(os.curdir) ) except urllib.error.URLError: LOG.error( "Error downloading assembly summary file for {:s}, " "exiting\n".format(args.library) ) sys.exit(1) if args.library == "human": with open("assembly_summary.txt", "r") as f1: with open("grc.txt", "w") as f2: for line in f1: if line.find("Genome Reference Consortium"): f2.write(line) os.rename("grc.txt", "assembly_summary.txt") with open("assembly_summary.txt", "r") as f: filepath_to_url = {} filepath_to_taxid_table = make_manifest_from_assembly_summary( args, f ) for filepath in filepath_to_taxid_table: filepath_to_url[filepath] = url_join(NCBI_SERVER, path=filepath) download_files_from_manifest( NCBI_SERVER, args.threads, filepath_to_taxid_table=filepath_to_taxid_table, resume=args.resume ) sequence_to_url = assign_taxid_to_sequences( args, filepath_to_taxid_table, filepath_to_url=filepath_to_url ) with open(library_filename, "r") as in_file: with open("prelim_map.txt", "w") as out_file: out_file.write("# prelim_map for " + args.library + "\n") scan_fasta_file( in_file, out_file, sequence_to_url=sequence_to_url, ) elif args.library in ["plasmid", "plastid", "mitochondrion"]: library_pathname = os.path.join(args.db, "library") library_pathname = os.path.join(library_pathname, args.library) library_filename = "library.faa" if args.protein else "library.fna" library_filename = os.path.join(library_pathname, library_filename) os.makedirs(library_pathname, exist_ok=True) os.chdir(library_pathname) pat = ".faa.gz" if args.protein else ".fna.gz" md5 = get_manifest_and_md5sums( NCBI_SERVER, "genomes/refseq/{}/".format(args.library), pat ) download_files_from_manifest( NCBI_SERVER, args.threads, md5sums=md5, resume=args.resume ) sequence_to_url = {} filenames = [] with open("manifest.txt", "r") as manifest: filenames = manifest.readlines() filenames = sorted(filenames) with concurrent.futures.ProcessPoolExecutor( max_workers=args.threads ) as pool: futures = [] for filename in filenames: filename = filename.strip() filename = os.path.abspath(filename) if not args.no_masking: f = functools.partial( wrap_with_globals, decompress_and_mask, LOG.get_queue(), LOG.get_level(), SCRIPT_PATHNAME ) future = pool.submit( f, filename, args.masker_threads ) if future_raised_exception(future): LOG.error( "Error encountered while decompressing" " or masking files\n" ) raise future.exception() futures.append(future) else: f = functools.partial( wrap_with_globals, decompress_files, LOG.get_queue(), LOG.get_level(), SCRIPT_PATHNAME ) future = pool.submit(f, [filename]) if future_raised_exception(future): LOG.error( "Error encountered while decompressing files\n" ) raise future.exception() futures.append(future) result = concurrent.futures.wait( futures, return_when=concurrent.futures.ALL_COMPLETED ) if len(result.not_done) > 0: LOG.error( "Encountered error while downloading Plasmid library\n" ) sys.exit(1) 
LOG.info("Generating {}\n".format(library_filename)) with open(library_filename, "w") as out_file: for filename in filenames: in_filename = os.path.splitext(filename)[0] with open(os.path.abspath(in_filename), "r") as in_file: for line in in_file: if line.startswith(">"): sequence_to_url[line.strip()] = filename out_file.write(line) LOG.info("Finished generating {}\n".format(library_filename)) with open(library_filename, "r") as in_file: with open("prelim_map.txt", "w") as out_file: out_file.write("# prelim_map for " + args.library + "\n") scan_fasta_file( in_file, out_file, sequence_to_url=sequence_to_url, ) elif args.library in ["core_nt", "nt", "env_nt", "nt_viruses", "nt_euk"]: library_pathname = os.path.join(library_pathname, args.library) os.makedirs(library_pathname, exist_ok=True) os.chdir(library_pathname) create_manifest_for_blast_db(args.library, args.blast_volumes) download_and_process_blast_volumes(args) elif args.library in ["UniVec", "UniVec_Core"]: if args.protein: LOG.error( "{:s} is available for nucleotide databases only\n".format( args.library ) ) sys.exit(1) library_pathname = os.path.join(library_pathname, args.library) os.makedirs(library_pathname, exist_ok=True) os.chdir(library_pathname) http_download_file2( NCBI_SERVER, ["pub/UniVec/" + args.library], save_to=os.path.abspath(os.curdir), ) special_taxid = 28384 LOG.info( "Assigning taxonomy ID of {:d} to all sequences\n".format( special_taxid ) ) with open(args.library, "r") as in_file: with open("library.fna", "w") as out_file: for line in in_file: if line.startswith(">"): line = re.sub( ">", ">kraken:taxid|" + str(special_taxid) + "|", line, ) out_file.write(line) with open("library.fna", "r") as in_file: with open("prelim_map.txt", "w") as out_file: out_file.write("# prelim_map for " + args.library + "\n") scan_fasta_file( in_file, out_file, sequence_to_url="pub/UniVec/" + args.library, ) else: if args.library.upper().startswith("GCF")\ or args.library.upper().startswith("GCA"): download_dataset_by_project(args, "genome/accession", [args.library.upper()]) elif args.library.upper().startswith("PRJ"): download_dataset_by_project(args, "genome/bioproject", [args.library.upper()]) else: download_dataset_by_project(args, "genome/taxon", [args.library]) if not args.no_masking\ and args.library in ["UniVec", "UniVec_Core"]: mask_files( [library_filename], library_filename + ".masked", args.masker_threads, args.protein ) shutil.move(library_filename + ".masked", library_filename) LOG.info("Added {:s} to {:s}\n".format(args.library, args.db)) def get_abs_path(filename): return os.path.abspath(filename) def is_compressed(filename): bzip_magic = b"\x42\x5A\x68" gzip_magic = b"\x1F\x8B" xz_magic = b"\xFD\x37\x7A\x58\x5A\x00" nbytes = len(xz_magic) with open(filename, "rb") as f: data = f.read(nbytes) if data.startswith((bzip_magic, gzip_magic, xz_magic)): return True return False def get_reader(filename): bzip_magic = b"\x42\x5A\x68" gzip_magic = b"\x1F\x8B" xz_magic = b"\xFD\x37\x7A\x58\x5A\x00" nbytes = len(xz_magic) with open(filename, "rb") as f: data = f.read(nbytes) if data.startswith(bzip_magic): return bz2.open elif data.startswith(gzip_magic): return gzip.open elif data.startswith(xz_magic): return lzma.open else: return open def read_from_files(filename1, filename2=None): reader1 = get_reader(filename1) reader2 = None if filename2 is not None: reader2 = get_reader(filename2) if reader2 is None: with reader1(filename1, "rb") as f: for seq in f: yield seq else: with reader1(filename1, "rb") as f1, reader2(filename2, 
"rb") as f2: for seq1, seq2 in itertools.zip_longest(f1, f2): if seq1 is None: LOG.error( "{} contains more sequences than {}".format( filename1, filename2 ) ) sys.exit(1) if seq2 is None: LOG.error( "{} contains more sequences than {}".format( filename2, filename1 ) ) sys.exit(1) yield (seq1, seq2) def write_to_fifo(filenames, fifo1=None, fifo2=None): if fifo2 is not None: with open(fifo1, "wb") as file1, open(fifo2, "wb") as file2: for fn1, fn2 in zip(filenames[0::2], filenames[1::2]): for seq1, seq2 in read_from_files(fn1, fn2): file1.write(seq1) file2.write(seq2) else: with open(fifo1, "wb") as file1: for fn in filenames: for seq in read_from_files(fn): file1.write(seq) def check_seqidmap(): LOG.info( "Checking if there are invalid taxid in seqid2taxid.map. " "These taxids will be logged if found and removed from the file\n" ) taxonomy_nodes = {} with open(os.path.join("taxonomy", "nodes.dmp"), "r") as fin: for entry in fin: taxid, parent_taxid = entry.split("\t|")[:2] taxonomy_nodes[taxid.strip()] = parent_taxid.strip() with open("seqid2taxid.map.new", "w") as fout: with open("seqid2taxid.map", "r") as fin: for line in fin: seqid, taxid = line.split("\t") taxid = taxid.strip() if taxid in taxonomy_nodes: fout.write(line) else: LOG.warning( "There is no entry for taxid, '{}', contained in, {}," "in nodes.dmp. Please contact NCBI about this\n" .format(taxid, seqid) ) shutil.move("seqid2taxid.map.new", "seqid2taxid.map") def suffix_to_multiplier(suffix): name_to_size = { 'byte': 1, 'kebibyte': 2 ** 10, 'mebibyte': 2 ** 20, 'gebibyte': 2 ** 30, 'tebibyte': 2 ** 40, 'kilobyte': 10 ** 3, 'megabyte': 10 ** 6, 'gigabyte': 10 ** 9, 'terabyte': 10 ** 12 } unit_to_size = { 'B': 1, 'KiB': 2 ** 10, 'KB': 10 ** 3, 'MiB': 2 ** 20, 'MB': 10 ** 6, 'GiB': 2 ** 30, 'GB': 10 ** 9, 'TiB': 2 ** 40, 'TB': 10 ** 12, } original_suffix = suffix if suffix in unit_to_size: return unit_to_size[suffix] if suffix.lower().endswith("s"): suffix = suffix.lower()[:-1] if suffix in name_to_size: return name_to_size[suffix] LOG.error("Unable to convert {} to a storage unit\n".format(original_suffix)) sys.exit(1) def parse_db_size(input): if input.isdigit(): return int(input) input = input.replace(" ", "") number = "".join(itertools.takewhile(str.isnumeric, input)) suffix = "".join(itertools.takewhile(str.isalpha, input[len(number):])) number = int(number) multiplier = suffix_to_multiplier(suffix) return number * multiplier def build_kraken2_db(args): if not os.path.isdir(get_abs_path(args.db)): LOG.error('Cannot find Kraken 2 database: "{:s}\n'.format(args.db)) sys.exit(1) os.chdir(args.db) if not os.path.isdir("taxonomy"): LOG.error("Cannot find taxonomy subdirectory in database\n") sys.exit(1) if not os.path.isdir("library"): LOG.error("Cannot find library subdirectory in database\n") sys.exit(1) prelim_map_filepaths = [] prelim_map_mtime = 0 if os.path.isdir("library"): glob_path = os.path.join("library", "*") prelim_map_filepaths = glob.glob( os.path.join(glob_path, "prelim_map*.txt") ) for prelim_map_filepath in prelim_map_filepaths: mtime = os.path.getmtime(prelim_map_filepath) if mtime > prelim_map_mtime: prelim_map_mtime = mtime if os.path.exists("seqid2taxid.map") and \ os.path.getmtime("seqid2taxid.map") > prelim_map_mtime: LOG.info( "A seqid2taxid.map already present and newer" " than any of the prelim_map.txt files, skipping\n" ) else: LOG.info("Concatenating prelim_map.txt files\n") with open("prelim_map.txt", "w") as out_file: for prelim_map_filepath in prelim_map_filepaths: with open(prelim_map_filepath, 
"r") as in_file: shutil.copyfileobj(in_file, out_file) if os.path.getsize("prelim_map.txt") == 0: os.remove("prelim_map.txt") LOG.error( "No preliminary seqid/taxid mapping files found, aborting\n" ) sys.exit(1) LOG.info("Finished concatenating prelim_map.txt files\n") LOG.info("Creating sequence ID to taxonomy ID map\n") with open("prelim_map.txt", "r") as in_file: with open("seqid2taxid.map.tmp", "w") as seqid2taxid_file: with open("accmap.tmp", "w") as accmap_file: for line in in_file: if line.startswith("#"): continue line = line.strip() new_line = "\t".join(line.split("\t")[1:3]) + "\n" if line.startswith("TAXID"): seqid2taxid_file.write(new_line) elif line.startswith("ACCNUM"): accmap_file.write(new_line) if os.path.getsize("accmap.tmp") > 0: accession2taxid_filenames = glob.glob("taxonomy/*.accession2taxid") if accession2taxid_filenames: lookup_accession_numbers( "accmap.tmp", "seqid2taxid.map.tmp", *accession2taxid_filenames ) else: LOG.error( "Accession to taxid map files are required to" " build this database.\n" ) LOG.error( "Run k2 download-taxonomy --db {:s} again".format(args.db) ) sys.exit(1) os.remove("accmap.tmp") move("seqid2taxid.map.tmp", "seqid2taxid.map") LOG.info("Created sequence ID to taxonomy ID map\n") check_seqidmap() estimate_capacity_binary = find_kraken2_binary("estimate_capacity") argv = [estimate_capacity_binary, "-S", construct_seed_template(args)] if args.protein: argv.append("-X") wrapper_args_to_binary_args( args, argv, get_binary_options(estimate_capacity_binary) ) fasta_filenames = glob.glob( os.path.join("library", os.path.join("*", "*.f[an]a")), recursive=False ) estimate = "" total_sequences = 0 if os.path.exists("estimated_capacity"): estimated_capacity_mtime = \ os.path.getmtime("estimated_capacity") seqid_to_taxid_map_mtime = os.path.getmtime("seqid2taxid.map") if estimated_capacity_mtime > seqid_to_taxid_map_mtime: LOG.info( "An estimated_capacity file exists and is newer " "than seqid2taxid.map , reading the estimated " "capacity from estimated_capacity file.\n" ) with open("estimated_capacity", "r") as in_file: lines = in_file.readlines() if len(lines) == 1: estimate = lines[0].strip() elif len(lines) == 2: estimate = lines[0].strip() total_sequences = int(lines[1].strip()) mapped_sequences = {} with open("seqid2taxid.map", "rb") as in_file: for line in in_file: sequence_name = line.split()[0] mapped_sequences[sequence_name] = True if estimate == "": if not dwk2(): argv.extend(fasta_filenames) LOG.info("Running: " + " ".join(argv) + "\n") proc = subprocess.Popen( argv, stdin=subprocess.PIPE, stdout=subprocess.PIPE ) if dwk2(): for filename in fasta_filenames: with open(filename, "rb") as in_file: for line in in_file: if line.startswith(b'>'): sequence_name = line.split()[0] if sequence_name[1:] in mapped_sequences: total_sequences += 1 proc.stdin.write(line) estimate = proc.communicate()[0].decode() proc.stdin.close() with open("estimated_capacity", "w") as out_file: out_file.write(estimate + str(total_sequences) + "\n") required_capacity = (int(estimate.strip()) + 8192) / args.load_factor LOG.info( "Estimated hash table requirement: {:s}\n".format( format_bytes(required_capacity * 4) ) ) if args.max_db_size: args.max_db_size = parse_db_size(args.max_db_size) if args.max_db_size < required_capacity * 4: args.max_db_size = int(args.max_db_size / 4) LOG.info( "Maximum hash table size of {}, specified and is" " lower than the calculated estimated capacity of {}\n" .format( format_bytes(args.max_db_size * 4), format_bytes(required_capacity * 4)) 
) if os.path.isfile("hash.k2d"): LOG.info("Hash table already present, skipping build\n") else: LOG.info("Starting database build\n") build_db_bin = find_kraken2_binary("build_db") argv = [ build_db_bin, "-H", "hash.k2d.tmp", "-t", "taxo.k2d.tmp", "-o", "opts.k2d.tmp", "-n", "taxonomy", "-m", "seqid2taxid.map", "-c", str(required_capacity), "-S", construct_seed_template(args), ] if args.protein: argv.append("-X") wrapper_args_to_binary_args( args, argv, get_binary_options(build_db_bin) ) LOG.info("Running: " + " ".join(argv) + "\n") if total_sequences > 0: m_err, s_err = pty.openpty() cat_proc = subprocess.Popen( ["cat"] + fasta_filenames, stdout=subprocess.PIPE ) build_proc = subprocess.Popen( argv, stdin=cat_proc.stdout, stdout=s_err, stderr=s_err, ) thread = threading.Thread( target=read_from_stderr, args=(m_err, total_sequences) ) thread.start() else: build_proc = subprocess.Popen( argv, stdin=subprocess.PIPE, ) build_proc.communicate() cat_proc.stdout.close() if build_proc.returncode != 0: LOG.error( "Encountered error while building database: " "build process died unexpectedly\n" ) if total_sequences > 0: os.close(s_err) thread.join() if build_proc.wait() != 0 or cat_proc.wait() != 0: os.close(m_err) return move("hash.k2d.tmp", "hash.k2d") move("taxo.k2d.tmp", "taxo.k2d") move("opts.k2d.tmp", "opts.k2d") LOG.info("Finished building database\n") def decompress_with_zlib(filename): inflator = zlib.decompressobj(15 + 32) with open(filename, "rb") as infile: while True: data = infile.read(8196) if not data: break inflator.decompress(data) def read_from_stderr(fd, total_sequences): pb = ProgressBar(total_sequences) buffer = b"" processing = True while processing: data = os.read(fd, 1024) if len(data) == 0: processing = False data = buffer + data buffer = b"" for line in data.splitlines(True): fields = line.split() if line.startswith(b"Processed") and len(fields) > 1: buffer = b"" progress = int(fields[1]) pb.progress(progress) eol = '\r' if pb.current == total_sequences: processing = False eol = '\n' LOG.debug( "Processed:" + pb.get_bar() + " {}/{}{}" .format(progress, total_sequences, eol) ) elif line.endswith(b"\n"): LOG.debug(line.decode()) else: buffer = line os.close(fd) # Parses RDP sequence data to create Kraken taxonomy # and sequence ID -> taxonomy ID mapping def build_rdp_taxonomy(f): seqid_map = {} seen_it = {} child_data = {"root;no rank": {}} for line in f: if not line.startswith(">"): continue line = line.strip() seq_label, taxonomy_string = line.split("\t") seqid = seq_label.split(" ")[0] taxonomy_string = re.sub( "^Lineage=Root;rootrank;", "root;no rank;", taxonomy_string ) taxonomy_string = re.sub(";$", ";no rank", taxonomy_string) seqid_map[seqid] = taxonomy_string seen_it.setdefault(taxonomy_string, 0) seen_it[taxonomy_string] += 1 if seen_it[taxonomy_string] > 1: continue while True: match = re.search("(;[^;]+;[^;]+)$", taxonomy_string) if match is None: break level = match.group(1) taxonomy_string = re.sub(";[^;]+;[^;]+$", "", taxonomy_string) key = taxonomy_string + level child_data.setdefault(taxonomy_string, {}) seen_it.setdefault(taxonomy_string, 0) child_data[taxonomy_string].setdefault(key, 0) child_data[taxonomy_string][key] += 1 seen_it[taxonomy_string] += 1 if seen_it[taxonomy_string] > 1: break id_map = {} next_node_id = 1 with open("names.dmp", "w") as names_file: with open("nodes.dmp", "w") as nodes_file: bfs_queue = [["root;no rank", 1]] while len(bfs_queue) > 0: node, parent_id = bfs_queue.pop() match = re.search("([^;]+);([^;]+)$", node) if match is 
None: LOG.error( 'BFS processing encountered formatting eror, "{:s}"\n' .format(node) ) sys.exit(1) display_name, rank = match.group(1), match.group(2) if rank == "domain": rank = "superkingdom" node_id, next_node_id = next_node_id, next_node_id + 1 id_map[node] = node_id names_file.write( "{:d}\t|\t{:s}\t|\t-\t|\tscientific name\t|\n".format( node_id, display_name ) ) nodes_file.write( "{:d}\t|\t{:d}\t|\t{:s}\t|\t-\t|\n".format( node_id, parent_id, rank ) ) children = ( sorted([key for key in child_data[node]]) if node in child_data else [] ) for node in children: bfs_queue.insert(0, [node, node_id]) with open("seqid2taxid.map", "w") as f: for seqid in sorted([key for key in seqid_map]): taxid = id_map[seqid_map[seqid]] f.write("{:s}\t{:d}\n".format(seqid, taxid)) # Build the standard Kraken database def build_standard_database(args): # download_taxonomy(args) args.assembly_source = "refseq" args.assembly_levels = ["chromosome", "complete_genome"] args.resume = True for library in [ "archaea", "bacteria", "viral", "plasmid", "human", "UniVec_Core", ]: if library == "UniVec_Core" and args.protein: continue args.library = library download_genomic_library(args) build_kraken2_db(args) # Parses Silva taxonomy file to create Kraken taxonomy def build_silva_taxonomy(in_file): id_map = {"root": 1} with open("names.dmp", "w") as names_file: with open("nodes.dmp", "w") as nodes_file: names_file.write("1\t|\troot\t|\t-\t|\tscientific name\t|\n") nodes_file.write("1\t|\t1\t|\tno rank\t|\t-\t|\n") for line in in_file: line = line.strip() taxonomy_string, node_id, rank = line.split("\t")[:3] id_map[taxonomy_string] = node_id match = re.search("^(.+;|)([^;]+);$", taxonomy_string) if match: parent_name = match.group(1) display_name = match.group(2) if parent_name == "": parent_name = "root" parent_id = id_map[parent_name] or None if not parent_id: LOG.error('orphan error: "{:s}"\n'.format(line)) sys.exit(1) if rank == "domain": rank = "superkingdom" names_file.write( "{:s}\t|\t{:s}\t|\t-\t|\tscientific name\t|\n".format( node_id, display_name ) ) nodes_file.write( "{:s}\t|\t{:s}\t|\t{:s}\t|\t-\t|\n".format( node_id, str(parent_id), rank ) ) else: LOG.error('strange input: "{:s}"\n'.format(line)) sys.exit(1) # Build a 16S database from Silva data def build_16S_silva(args): args.db = os.path.abspath(args.db) os.makedirs(args.db, exist_ok=True) os.chdir(args.db) for directory in ["data", "taxonomy", "library"]: os.makedirs(directory, exist_ok=True) os.chdir("data") remote_directory = "/release_138_2/Exports" fasta_filename = "SILVA_138.2_SSURef_NR99_tax_silva.fasta.gz" taxonomy_prefix = "tax_slv_ssu_138.2" ftp = FTP(SILVA_SERVER) ftp.download(remote_directory, fasta_filename) ftp.download( remote_directory + "/taxonomy", taxonomy_prefix + ".acc_taxid.gz" ) decompress_files([taxonomy_prefix + ".acc_taxid.gz"]) ftp.download(remote_directory + "/taxonomy", taxonomy_prefix + ".txt.gz") with gzip.open(taxonomy_prefix + ".txt.gz", "rt") as f: build_silva_taxonomy(f) os.chdir(os.path.pardir) move(os.path.join("data", "names.dmp"), "taxonomy") move(os.path.join("data", "nodes.dmp"), "taxonomy") move( os.path.join("data", taxonomy_prefix + ".acc_taxid"), "seqid2taxid.map" ) with gzip.open(os.path.join("data", fasta_filename), "rt") as in_file: os.chdir("library") os.makedirs("silva", exist_ok=True) os.chdir("silva") with open("library.fna", "w") as out_file: for line in in_file: if not line.startswith(">"): line = line.replace("U", "T") out_file.write(line) if not args.no_masking: filename = "library.fna" 
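# Low-complexity masking step: mask_files() (defined elsewhere in this
# script) writes a masked copy alongside the input, which then replaces
# the original library.fna. This step is skipped entirely when
# --no-masking is given.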
mask_files( [filename], filename + ".masked", args.threads ) shutil.move(filename + ".masked", filename) os.chdir(args.db) build_kraken2_db(args) # Parses Greengenes taxonomy file to create Kraken taxonomy # and sequence ID -> taxonomy ID mapping # Input: gg_13_5_taxonomy.txt def build_gg_taxonomy(in_file): rank_codes = { "k": "superkingdom", "p": "phylum", "c": "class", "o": "order", "f": "family", "g": "genus", "s": "species", } seqid_map = {} seen_it = {} child_data = {"root": {}} for line in in_file: line = line.strip() seqid, taxonomy_string = line.split("\t") taxonomy_string = re.sub("(; [a-z]__)+$", "", taxonomy_string) seqid_map[seqid] = taxonomy_string seen_it.setdefault(taxonomy_string, 0) seen_it[taxonomy_string] += 1 if seen_it[taxonomy_string] > 1: continue while True: match = re.search("(; [a-z]__[^;]+$)", taxonomy_string) if not match: break level = match.group(1) taxonomy_string = re.sub("(; [a-z]__[^;]+$)", "", taxonomy_string) child_data.setdefault(taxonomy_string, {}) key = taxonomy_string + level seen_it.setdefault(taxonomy_string, 0) child_data[taxonomy_string].setdefault(key, 0) child_data[taxonomy_string][key] += 1 seen_it[taxonomy_string] += 1 if seen_it[taxonomy_string] > 1: break if seen_it[taxonomy_string] == 1: child_data["root"].setdefault(taxonomy_string, 0) child_data["root"][taxonomy_string] += 1 id_map = {} next_node_id = 1 with open("names.dmp", "w") as names_file: with open("nodes.dmp", "w") as nodes_file: bfs_queue = [["root", 1]] while len(bfs_queue) > 0: node, parent_id = bfs_queue.pop() display_name = node rank = None match = re.search("g__([^;]+); s__([^;]+)$", node) if match: genus, species = match.group(1), match.group(2) rank = "species" if re.search(" endosymbiont ", species): display_name = species else: display_name = genus + " " + species else: match = re.search("([a-z])__([^;]+)$", node) if match: rank = rank_codes[match.group(1)] display_name = match.group(2) rank = rank or "no rank" node_id, next_node_id = next_node_id, next_node_id + 1 id_map[node] = node_id names_file.write( "{:d}\t|\t{:s}\t|\t-\t|\tscientific name\t|\n".format( node_id, display_name ) ) nodes_file.write( "{:d}\t|\t{:d}\t|\t{:s}\t|\t-\t|\n".format( node_id, parent_id, rank ) ) children = ( sorted([key for key in child_data[node]]) if node in child_data else [] ) for node in children: bfs_queue.insert(0, [node, node_id]) with open("seqid2taxid.map", "w") as f: for seqid in sorted([key for key in seqid_map], key=int): taxid = id_map[seqid_map[seqid]] f.write("{:s}\t{:d}\n".format(seqid, taxid)) # Build a 16S database from Greengenes data def build_16S_gg(args): args.db = os.path.abspath(args.db) os.makedirs(args.db, exist_ok=True) gg_version = "gg_13_5" remote_directory = "/greengenes_release/" + gg_version os.chdir(args.db) for directory in ["data", "taxonomy", "library"]: os.makedirs(directory, exist_ok=True) os.chdir("data") ftp = FTP(GREENGENES_SERVER) ftp.download(remote_directory, gg_version + ".fasta.gz") decompress_files([gg_version + ".fasta.gz"]) ftp.download(remote_directory, gg_version + "_taxonomy.txt.gz") decompress_files([gg_version + "_taxonomy.txt.gz"]) with open(gg_version + "_taxonomy.txt", "r") as f: build_gg_taxonomy(f) os.chdir(os.path.abspath(os.path.pardir)) move(os.path.join("data", "names.dmp"), "taxonomy") move(os.path.join("data", "nodes.dmp"), "taxonomy") move(os.path.join("data", "seqid2taxid.map"), os.getcwd()) move( os.path.join("data", gg_version + ".fasta"), os.path.join("library", "library.fna"), ) os.chdir("library") 
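# Expected database layout at this point (illustrative): <db>/data holds
# the raw Greengenes downloads, <db>/taxonomy holds names.dmp and
# nodes.dmp, <db>/seqid2taxid.map maps sequence IDs to taxids, and the
# FASTA is about to be moved into <db>/library/greengenes/library.fna.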
os.makedirs("greengenes", exist_ok=True) move("library.fna", "greengenes") os.chdir("greengenes") if not args.no_masking: filename = "library.fna" mask_files([filename], filename + ".masked", args.threads) move(filename + ".masked", filename) os.chdir(args.db) build_kraken2_db(args) # Build a 16S data from RDP data def build_16S_rdp(args): os.makedirs(args.db, exist_ok=True) os.chdir(args.db) for directory in ["data", "taxonomy", "library"]: os.makedirs(directory, exist_ok=True) os.chdir("data") http_download_file( "http://rdp.cme.msu.edu/download/current_Bacteria_unaligned.fa.gz" ) http_download_file( "http://rdp.cme.msu.edu/download/current_Archaea_unaligned.fa.gz" ) decompress_files(glob.glob("*gz")) for filename in glob.glob("current_*_unaligned.fa"): with open(filename, "r") as f: build_rdp_taxonomy(f) os.chdir(os.pardir) move(os.path.join("data", "names.dmp"), "taxonomy") move(os.path.join("data", "nodes.dmp"), "taxonomy") move(os.path.join("data", "seqid2taxid.map"), os.getcwd()) for filename in glob.glob(os.path.join("data", "*.fa")): new_filename = os.path.basename(re.sub(r"\.fa$", ".fna", filename)) shutil.move(filename, os.path.join("library", new_filename)) if not args.no_masking: new_filename = os.path.join("library", new_filename) mask_files( [new_filename], new_filename + ".masked", args.threads ) shutil.move(new_filename + ".masked", new_filename) build_kraken2_db(args) # Reads multi-FASTA input and examines each sequence header. In quiet # mode headers are OK if a taxonomy ID is found (as either the entire # sequence ID or as part of a "kraken:taxid" token), or if something # looking like a GI or accession number is found. In normal mode, the # taxonomy ID will be looked up (if not explicitly specified in the # sequence ID) and reported if it can be found. Output is # tab-delimited lines, with sequence IDs in first column and taxonomy # IDs in second. # Sequence IDs with a kraken:taxid token will use that to assign taxonomy # ID, e.g.: # >gi|32499|ref|NC_021949.2|kraken:taxid|562| # # Sequence IDs that are completely numeric are assumed to be the taxonomy # ID for that sequence. # # Otherwise, an accession number is searched for; if not found, a GI # number is searched for. Failure to find any of the above is a fatal error. # Without `quiet`, a comma-separated file list specified by -A (for both accession # numbers and GI numbers) is examined; failure to find a # taxonomy ID that maps to a provided accession/GI number is non-fatal and # will emit a warning. # # With -q, does not print any output, and will die w/ nonzero exit instead # of warning when unable to find a taxid, accession #, or GI #. 
# def make_seqid_to_taxid_map( in_file, quiet, accession_map_filenames=False, library_map_filename=None ): target_lists = {} for line in in_file: match = re.match(r">(\S+)", line) if match is None: continue seqid = match.group(1) output = None regexes = [ r"(?:^|\|)kraken:taxid\|(\d+)", r"^\d+$", r"(?:^|\|)([A-Z]+_?[A-Z0-9]+)(?:\||\b|\.)", r"(?:^|\|)gi\|(\d+)", ] match = None index = None for i, regex in enumerate(regexes): match = re.match(regex, seqid) if match: index = i break if index == 0: output = seqid + "\t" + match.group(1) + "\n" elif index == 1: output = seqid + "\t" + seqid + "\n" elif index in [2, 3]: if not quiet: capture = match.group(1) target_lists.setdefault(capture, []) target_lists[capture].insert(0, seqid) else: LOG.error( "Unable to determine taxonomy ID for sequence {:s}\n".format( seqid ) ) sys.exit(1) if output and not quiet: print(output) if quiet: if len(target_lists) == 0: LOG.error("External map required\n") sys.exit(0) if len(target_lists) == 0: sys.exit(0) if not accession_map_filenames and library_map_filename is None: LOG.error( "Found sequence ID without explicit taxonomy ID, but no map used\n" ) sys.exit(1) # Remove targets where we've already handled the mapping if library_map_filename: with open(library_map_filename, "r") as f: for line in f: line = line.strip() seqid, taxid = line.split("\t") if seqid in target_lists: print("{:s}\t{:s}\n".format(seqid, taxid)) del target_lists[seqid] if len(target_lists) == 0: sys.exit(0) for filename in accession_map_filenames: with open(filename, "r") as f: f.readline() for line in f: line = line.strip() accession, with_version, taxid, gi = line.split("\t") if accession in target_lists: target_list = target_lists[accession] del target_lists[accession] for seqid in target_list: print("{:s}\t{:s}".format(seqid, taxid)) if gi != "na" and gi in target_lists: target_list = target_lists[gi] del target_lists[gi] for seqid in target_list: print("{:s}\t{:s}\n".format(seqid, taxid)) def wait_for_files(*fifos): for fifo in fifos: while not os.path.exists(fifo): continue def cleanup_fifos(): LOG.info("Cleaning up fifos and pid files\n") for filename in os.listdir("/tmp"): if re.match(r"^classify_(?:\d+_)?(?:stdin|stdout)$", filename): os.remove(os.path.join("/tmp", filename)) elif filename == "classify.pid": os.remove(os.path.join("/tmp", "classify.pid")) def copy_file_obj(dst): while True: data = sys.stdin.read(8194) if not data: break dst.write(data) dst.close() def check_daemon(): if not os.path.exists("/tmp/classify.pid"): return False with open("/tmp/classify.pid", "r") as in_file: pid = in_file.readline().strip() null = os.open(os.devnull, os.O_WRONLY) alive = subprocess.call( ["ps", pid], stdout=null, stderr=null ) os.close(null) return alive == 0 def message_daemon(message): alive = check_daemon() if not alive: return fd_rd_1 = os.open("/tmp/classify_stdin", os.O_RDONLY | os.O_NONBLOCK) fd_wr_1 = os.open("/tmp/classify_stdin", os.O_WRONLY) fd_rd_2 = os.open("/tmp/classify_stdout", os.O_RDONLY | os.O_NONBLOCK) os.set_blocking(fd_rd_1, True) alive = False try: os.write(fd_wr_1, message) time.sleep(0.1) if os.read(fd_rd_2, 3) == b"OK\n": # put some log here that the daemon is stopped alive = True except BlockingIOError: alive = False os.close(fd_rd_1) os.close(fd_rd_2) os.close(fd_wr_1) def classify_using_daemon(args, argv): alive = check_daemon() if not alive: cleanup_fifos() LOG.info("Starting backgroud classifier process\n") subprocess.call(argv) wait_for_files( "/tmp/classify.pid", "/tmp/classify_stdin", 
"/tmp/classify_stdout" ) with open("/tmp/classify.pid", "r") as in_file: pid = in_file.readline().strip() LOG.info( "Started background classifier process with PID: {}\n" .format(pid) ) LOG.info("Run k2 clean --stop-daemon to stop it.\n") else: fd = os.open("/tmp/classify_stdin", os.O_RDWR) with os.fdopen(fd, 'w') as out_file: out_file.write(" ".join(argv) + "\n") # the daemon will return the pid of the subprocess # doing the work with open("/tmp/classify_stdout", "r") as in_file: for line in in_file: line = line.strip() if line.startswith("PID"): pid = line.split(":")[1].strip() break else: print(line) proc_in = "/tmp/classify_{}_stdin".format(pid) proc_out = "/tmp/classify_{}_stdout".format(pid) wait_for_files(proc_in, proc_out) thread = None proc_fd = open(proc_in, "w") if len(args.filenames) == 0: thread = threading.Thread( target=copy_file_obj, args=(proc_fd,) ) thread.start() if "output" not in args: with open(proc_out, "r") as in_file: for line in in_file: line = line.strip() print(line) # wait for the classification job to complete if thread: thread.join() with open("/tmp/classify_stdout", "r") as in_file: try: for line in in_file: line = line.strip() if line == "DONE": break except KeyboardInterrupt: os.kill(int(pid), signal.SIGINT) for line in in_file: line = line.strip() if line == "DONE": break # TODO: modify to include the scientific name? def write_raw_sequence_to_file(in_file, out_file, header, taxid=None): fastq = False lines_printed = 0 if chr(header[0]) == '@': fastq = True if taxid: header = header.strip() + (" kraken:taxid|" + str(taxid) + "\n").encode() if out_file: out_file.write(header) lines_printed += 1 header = "" while True: if fastq and lines_printed == 4: break if len(in_file.peek()) == 0: break if not fastq and chr(in_file.peek()[0]) == '>': break line = in_file.readline() if out_file: out_file.write(line) lines_printed += 1 def process_unpaired(input_filenames, classified_out_filename, unclassified_out_filename, classified_headers): if classified_out_filename: classified_out_file = open(classified_out_filename, "wb") if unclassified_out_filename: unclassified_out_file = open(unclassified_out_filename, "wb") for filename in input_filenames: with open(filename, "rb") as in_file: for line in in_file: read_name = line.split(b' ', 1)[0] read_name = read_name[1:].strip() classified = read_name.decode() in classified_headers if classified and classified_out_filename: taxid = classified_headers[read_name.decode()] write_raw_sequence_to_file( in_file, classified_out_file, line, taxid ) elif not classified and unclassified_out_filename: write_raw_sequence_to_file( in_file, unclassified_out_file, line ) else: write_raw_sequence_to_file(in_file, None, line) if classified_out_filename: classified_out_file.close() if unclassified_out_filename: unclassified_out_file.close() def process_paired(input_filenames, classified_out_filename, unclassified_out_filename, classified_headers): if classified_out_filename: classified_out_filename1 =\ classified_out_filename.replace('#', "_1") classified_out_file1 = open(classified_out_filename1, "wb") classified_out_filename2 =\ classified_out_filename.replace('#', "_2") classified_out_file2 = open(classified_out_filename2, "wb") if unclassified_out_filename: unclassified_out_filename1 =\ unclassified_out_filename.replace('#', "_1") unclassified_out_file1 = open(unclassified_out_filename1, "wb") unclassified_out_filename2 =\ unclassified_out_filename.replace('#', "_2") unclassified_out_file2 = open(unclassified_out_filename2, "wb") for 
filename1, filename2 in zip(input_filenames[0::2], input_filenames[1::2]): with open(filename1, "rb") as in_file1, open(filename2, "rb") as in_file2: for line1, line2 in zip(in_file1, in_file2): read_name = line1.split(b' ', 1)[0] read_name = read_name[1:].strip() classified = read_name.decode() in classified_headers if classified and classified_out_filename: taxid = classified_headers[read_name.decode()] write_raw_sequence_to_file( in_file1, classified_out_file1, line1, taxid ) write_raw_sequence_to_file( in_file2, classified_out_file2, line2, taxid ) elif not classified and unclassified_out_filename: write_raw_sequence_to_file( in_file1, unclassified_out_file1, line1 ) write_raw_sequence_to_file( in_file2, unclassified_out_file2, line2 ) else: write_raw_sequence_to_file( in_file1, None, line1 ) write_raw_sequence_to_file( in_file2, None, line2 ) if classified_out_filename: classified_out_file1.close() classified_out_file2.close() if unclassified_out_filename: unclassified_out_file1.close() unclassified_out_file2.close() def process_interleaved( input_filenames, classified_out_filename, unclassified_out_filename, classified_headers): if classified_out_filename: classified_out_file = open(classified_out_filename, "wb") if unclassified_out_filename: unclassified_out_file = open(unclassified_out_filename, "wb") for filename in input_filenames: with open(filename, "rb") as in_file: for line in in_file: line2 = in_file.readline() read_name = line.split(b' ', 1)[0] read_name = read_name[1:].strip() classified = read_name.decode() in classified_headers if classified and classified_out_filename: taxid = classified_headers[read_name.decode()] write_raw_sequence_to_file( in_file, classified_out_file, line, taxid ) write_raw_sequence_to_file( in_file, classified_out_file, line2, taxid ) elif not classified and unclassified_out_filename: write_raw_sequence_to_file( in_file, unclassified_out_file, line ) write_raw_sequence_to_file( in_file, unclassified_out_file, line2 ) else: write_raw_sequence_to_file(in_file, None, line) write_raw_sequence_to_file(in_file, None, line2) if classified_out_filename: classified_out_file.close() if unclassified_out_filename: unclassified_out_file.close() def write_fasta_sequences( args, input_filenames, classified_out_filename, unclassified_out_filename, classified_headers): if "paired" in args: process_paired( input_filenames, classified_out_filename, unclassified_out_filename, classified_headers ) elif "interleaved" in args: process_interleaved( input_filenames, classified_out_filename, unclassified_out_filename, classified_headers ) else: process_unpaired( input_filenames, classified_out_filename, unclassified_out_filename, classified_headers ) class TaxonomyStruct(ctypes.Structure): pass class Taxonomy: def __init__(self, dll_pathname): self.dll = ctypes.CDLL(dll_pathname) self.dll.init_taxonomy.restype = ctypes.POINTER(TaxonomyStruct) self.dll.get_lca.restype = ctypes.c_uint64 self.dll.get_internal_taxid.restype = ctypes.c_uint64 self.dll.is_ancestor_of.restype = ctypes.c_bool self.dll.get_rank.restype = ctypes.c_char_p self.dll.get_child_count.restype = ctypes.c_uint64 self.dll.taxid_to_name.restype = ctypes.c_char_p # self.dll.get_child_taxids.restype = ctypes.c_ def generate_taxonomy(self, names, nodes, seqid2taxid, taxonomy_pathname): names = ctypes.c_char_p(names.encode()) nodes = ctypes.c_char_p(nodes.encode()) seqid2taxid = ctypes.c_char_p(seqid2taxid.encode()) taxonomy_pathname = ctypes.c_char_p(taxonomy_pathname.encode()) self.dll.generate_taxonomy( names,
nodes, seqid2taxid, taxonomy_pathname ) def load_taxonomy(self, taxonomy_pathname): tax_file = ctypes.c_char_p(taxonomy_pathname.encode()) self.taxonomy = self.dll.init_taxonomy(tax_file) def get_lca(self, taxid1, taxid2): taxid1 = int(taxid1) taxid2 = int(taxid2) return self.dll.get_lca(self.taxonomy, taxid1, taxid2) def get_internal_taxid(self, taxid): return self.dll.get_internal_taxid(self.taxonomy, taxid) def is_ancestor_of(self, parent, child): return self.dll.is_ancestor_of(self.taxonomy, parent, child) def destroy_taxonomy(self): self.dll.destroy_taxonomy(self.taxonomy) def get_rank(self, taxid): return self.dll.get_rank(self.taxonomy, taxid) def taxid_to_name(self, taxid): return self.dll.taxid_to_name(self.taxonomy, taxid) def get_child_count(self, taxid): return self.dll.get_child_count(self.taxonomy, taxid) def get_parent_id(self, taxid): return self.dll.get_parent_id(self.taxonomy, taxid) def get_child_taxids(self, taxid): num_children = self.get_child_count(taxid) child_taxids = (ctypes.c_uint64 * num_children)(*([0] * num_children)) self.dll.get_child_taxids( self.taxonomy, taxid, ctypes.byref(child_taxids), num_children ) return child_taxids class ReadCounts: def __init__(self): self.n_kmers = 0 self.n_reads = 0 def get_read_count(self): return self.n_reads def get_kmer_count(self): return self.n_kmers def increment_read_count(self): self.n_reads += 1 def __iadd__(self, other): self.n_kmers += other.n_kmers self.n_reads += other.n_reads return self class TaxonCounters: def __init__(self): self.counter = collections.defaultdict(ReadCounts) def __getitem__(self, taxid): return self.counter[taxid] def __setitem__(self, taxid, c): self.counter[taxid] = c def items(self): return self.counter.items() def keys(self): return self.counter.keys() def get_clade_counters(taxonomy, call_counters): clade_counters = TaxonCounters() for k, v in call_counters.items(): while k != 0: clade_counters[k] += v k = taxonomy.get_parent_id(k) return clade_counters def print_kraken_style_report(out_file, report_kmer_data, total_seqs, clade_counter, taxon_counter, rank_string, taxid, scientific_name, depth): read_count = clade_counter.get_read_count() percentage = 100.0 * read_count / total_seqs out_string = "{:6.2f}".format(percentage) +\ "\t" + str(clade_counter.get_read_count()) +\ "\t" + str(taxon_counter.get_read_count()) +\ "\t" + rank_string + "\t" + str(taxid) +\ "\t" + " " * depth + scientific_name + "\n" out_file.write(out_string) def kraken_report_dfs(out_file, taxid, report_zeros, report_kmer_data, taxonomy, clade_counters, call_counters, total_seqs, rank_code, rank_depth, depth): clade_counter = clade_counters[taxid] call_counter = call_counters[taxid] if not report_zeros and clade_counter.get_read_count() == 0: return rank = taxonomy.get_rank(taxid).decode() if rank == "superkingdom": rank_code = 'D' rank_depth = 0 elif rank == "kingdom": rank_code = 'K' rank_depth = 0 elif rank == "phylum": rank_code = 'P' rank_depth = 0 elif rank == "class": rank_code = 'C' rank_depth = 0 elif rank == "order": rank_code = 'O' rank_depth = 0 elif rank == "family": rank_code = 'F' rank_depth = 0 elif rank == "genus": rank_code = 'G' rank_depth = 0 elif rank == "species": rank_code = 'S' rank_depth = 0 else: rank_depth += 1 rank_string = rank_code if rank_depth > 0: rank_string += str(rank_depth) scientific_name = taxonomy.taxid_to_name(taxid).decode() print_kraken_style_report( out_file, False, total_seqs, clade_counter, call_counter, rank_string, taxid, scientific_name, depth ) children = sorted( 
taxonomy.get_child_taxids(taxid), key=lambda t: clade_counters[t].get_read_count() ) for child_taxid in children: kraken_report_dfs( out_file, child_taxid, report_zeros, report_kmer_data, taxonomy, clade_counters, call_counters, total_seqs, rank_code, rank_depth, depth + 1 ) def report_kraken_style(filename, report_zeros, report_kmer_data, taxonomy, call_counters, total_seqs): clade_counters = get_clade_counters(taxonomy, call_counters) total_unclassified = call_counters[0].get_read_count() rank_code = "R" with open(filename, "w") as out_file: if total_unclassified > 0: print_kraken_style_report( out_file, False, total_seqs, call_counters[0], call_counters[0], "U", 0, "unclassified", 0 ) kraken_report_dfs( out_file, 1, report_zeros, False, taxonomy, clade_counters, call_counters, total_seqs, rank_code, 0, 0 ) # Taken from ResolveTree function in classify.cc def resolve_taxa_tree(hit_counts, taxonomy, total_kmers, args): max_taxid = 0 max_score = 0 required_score = math.ceil(args.confidence * total_kmers) for taxid in hit_counts.keys(): score = 0 for taxid2, counts in hit_counts.items(): if taxonomy.is_ancestor_of(taxid2, taxid): score += counts if score > max_score: max_score, max_taxid = score, taxid elif score == max_score: max_taxid = taxonomy.get_lca(max_taxid, taxid) max_score = hit_counts[max_taxid] while max_taxid != 0 and max_score < required_score: max_score = 0 for taxid, counts in hit_counts.items(): if taxonomy.is_ancestor_of(max_taxid, taxid): max_score += counts if max_score >= required_score: return max_taxid else: max_taxid = taxonomy.get_parent_id(max_taxid) return max_taxid def parse_taxid_counts(string, counts): index = 0 counts_len = string.count(':') * 2 if counts is None: counts = array.array('I', range(0, counts_len)) _, counts_capacity = counts.buffer_info() if counts_capacity < counts_len: counts.extend(range(0, counts_len - counts_capacity)) for entry in string.split(): taxid, count = entry.split(':') if taxid == 'A': taxid, count = AMBIGUOUS_TAXID, int(count) # separator for paired reads, set the taxid and count # to an invalid value elif taxid == '|': taxid, count = 0, 0 else: taxid, count = int(taxid), int(count) counts[index] = taxid index += 1 counts[index] = count index += 1 return counts, counts_len def get_lca(taxonomy, cache, taxid1, taxid2): taxid1, taxid2 = (taxid1, taxid2) if taxid1 < taxid2 else (taxid2, taxid1) if (taxid1, taxid2) not in cache: cache[(taxid1, taxid2)] = taxonomy.get_lca(taxid1, taxid2) return cache[(taxid1, taxid2)] def merge_counts( taxonomy, lca_cache, counts1, counts1_len, counts2, counts2_len): index1 = 0 index2 = 0 total_minimizers = 0 counts_str = "" counts_map = collections.defaultdict(int) final_counts = [] while True: taxid1, count1 = counts1[index1], counts1[index1 + 1] taxid2, count2 = counts2[index2], counts2[index2 + 1] final_count = 0 # final_taxid = taxonomy.get_lca(taxid1, taxid2) if taxid1 == AMBIGUOUS_TAXID or taxid2 == AMBIGUOUS_TAXID: final_taxid = AMBIGUOUS_TAXID else: final_taxid = get_lca(taxonomy, lca_cache, taxid1, taxid2) if count1 < count2: counts2[index2 + 1] -= count1 index1 += 2 final_count = count1 if counts2[index2 + 1] == 0: index2 += 2 elif count2 < count1: counts1[index1 + 1] -= count2 index2 += 2 final_count = count2 if counts1[index1 + 1] == 0: index1 += 2 else: index1 += 2 index2 += 2 final_count = count1 total_minimizers += final_count counts_map[final_taxid] += final_count final_counts.append((final_taxid, final_count)) if index1 >= counts1_len: break def taxid_to_string(taxid): if taxid == 
AMBIGUOUS_TAXID: return "A" else: return str(taxid) counts_str = io.StringIO() previous_taxid, previous_count = final_counts[0] if len(final_counts) > 1: for taxid, count in final_counts[1:]: if taxid == 0 and count == 0: counts_str.write(taxid_to_string(previous_taxid)) counts_str.write(":") counts_str.write(str(previous_count)) counts_str.write(" ") counts_str.write("|") counts_str.write(":") counts_str.write("|") counts_str.write(" ") previous_taxid = taxid previous_count = count if previous_taxid == taxid: previous_count += count else: # counts_str += str(previous_taxid) + ":" + str(previous_count) # counts_str += " " counts_str.write(taxid_to_string(previous_taxid)) counts_str.write(":") counts_str.write(str(previous_count)) counts_str.write(" ") previous_taxid = taxid previous_count = count # counts_str += str(previous_taxid) + ":" + str(previous_count) counts_str.write(taxid_to_string(previous_taxid)) counts_str.write(":") counts_str.write(str(previous_count)) return (counts_str.getvalue(), counts_map, total_minimizers) def merge_classification_output( taxonomy, in_filename1, in_filename2, out_filename, use_names, args, final=False): call_counters = TaxonCounters() total_seqs = 0 counts_array1 = None counts_array2 = None lca_cache = {} with open(in_filename1) as file1, open(in_filename2) as file2: with open(out_filename, "w") as out_file: for (line1, line2) in zip(file1, file2): (status1, name1, taxid1, len1, counts1) =\ line1.strip().split('\t', 5) (status2, name2, taxid2, len2, counts2) =\ line2.strip().split('\t', 5) if status1 == "C" and status2 == "C": status = "C" counts_array1, counts_len1 =\ parse_taxid_counts(counts1, counts_array1) counts_array2, counts_len2 =\ parse_taxid_counts(counts2, counts_array2) counts, counts_map, total_minimizers =\ merge_counts( taxonomy, lca_cache, counts_array1, counts_len1, counts_array2, counts_len2 ) taxid = resolve_taxa_tree( counts_map, taxonomy, total_minimizers, args ) elif status1 == "C" and status2 == "U": status = "C" taxid = taxid1 counts = counts1 elif status1 == "U" and status2 == "C": status = "C" taxid = taxid2 counts = counts2 else: status = "U" taxid = 0 counts = "0:0" if final: call_counters[int(taxid)].increment_read_count() total_seqs += 1 if use_names: scientific_name = taxonomy.taxid_to_name(int(taxid)) records = "\t".join( [status, name1, scientific_name.decode(), "(taxid " + str(taxid) + ")", len1] ) else: records = "\t".join( [status, name1, str(taxid), len1] ) out_file.write(records) out_file.write("\t") out_file.write(counts) out_file.write("\n") return (call_counters, total_seqs) if final else (None, total_seqs) def merge_classification_output2( taxonomy_pathname, lines, job_number, use_names, args, save_seq_names, final): taxonomy_dll_pathname = find_kraken2_binary("libtax.so") taxonomy = Taxonomy(taxonomy_dll_pathname) taxonomy.load_taxonomy(taxonomy_pathname) call_counters = TaxonCounters() total_seqs = 0 counts_array1 = None classified_headers = {} counts_array2 = None lca_cache = {} out_filename = tempfile.mktemp( prefix="k2_job" + str(job_number) + "_" ) with open(out_filename, "w") as out_file: for (left, right) in lines: (status1, seq_name1, taxid1, len1, counts1) =\ left.strip().split('\t', 5) (status2, seq_name2, taxid2, len2, counts2) =\ right.strip().split('\t', 5) if status1 == "C" and status2 == "C": status = "C" counts_array1, counts_len1 =\ parse_taxid_counts(counts1, counts_array1) counts_array2, counts_len2 =\ parse_taxid_counts(counts2, counts_array2) counts, counts_map, total_minimizers =\ 
merge_counts( taxonomy, lca_cache, counts_array1, counts_len1, counts_array2, counts_len2 ) taxid = resolve_taxa_tree( counts_map, taxonomy, total_minimizers, args ) elif status1 == "C" and status2 == "U": status = "C" taxid = taxid1 counts = counts1 elif status1 == "U" and status2 == "C": status = "C" taxid = taxid2 counts = counts2 else: status = "U" taxid = 0 counts = "0:0" if final: call_counters[int(taxid)].increment_read_count() if save_seq_names: if status == "C": classified_headers[seq_name1] = taxid total_seqs += 1 if use_names: scientific_name = taxonomy.taxid_to_name(int(taxid)) records = "\t".join( [status, seq_name1, scientific_name.decode() + " " + "(taxid " + str(taxid) + ")", len1] ) else: records = "\t".join( [status, seq_name1, str(taxid), len1] ) out_file.write(records) out_file.write("\t") out_file.write(counts) out_file.write("\n") if final: return ( out_filename, call_counters, total_seqs, classified_headers ) else: return (out_filename, None, None, None) def merge_classification_output_parallel( pool, taxonomy_pathname, in_filename1, in_filename2, out_filename, use_names, args, save_seq_names, final): filenames = [] call_counters = TaxonCounters() total_seqs = 0 with open(in_filename1) as file1, open(in_filename2) as file2: input = list(zip(file1.readlines(), file2.readlines())) input_len = len(input) partition_ranges = list(range(0, input_len, int(input_len / args.threads))) partition_ranges[-1] = input_len job_number = 0 futures = [] for start, end in zip(partition_ranges, partition_ranges[1:]): future = pool.submit( merge_classification_output2, taxonomy_pathname, input[start:end], job_number, use_names, args, save_seq_names, final ) futures.append(future) job_number += 1 done, not_done = concurrent.futures.wait(futures) classified_headers = [] for future in done: filename, counters, total, classified_set =\ future.result() filenames.append(filename) if final: classified_headers.append(classified_set) total_seqs += total for key, value in counters.items(): call_counters[key] += value with open(out_filename, "w") as out_file: for filename in sorted(filenames): with open(filename, "r") as in_file: shutil.copyfileobj(in_file, out_file) os.remove(filename) if final: classified_headers = collections.ChainMap(*classified_headers) return (call_counters, total_seqs, classified_headers) else: return (None, total_seqs, None) def classify_multi_dbs(args): dbs = args.db use_names = False report_filename = None report_zeros = False classified_out_filename = None unclassified_out_filename = None if "use_mpa_style" in args: LOG.error( "--use-mpa-style not supported when using multiple dbs\n" ) sys.exit(1) if "report_minimizer_data" in args: LOG.error( "--report-minimizer-data not supported when using " "multiple dbs\n" ) sys.exit(1) if "report" in args: report_filename, args.report = (args.report, None) output = args.output if "output" in args else None if "report_zeros" in args: report_zeros = True args.report_zeros = None if "classified_out" in args: classified_out_filename = args.classified_out args.classified_out = None if "unclassified_out" in args: unclassified_out_filename = args.unclassified_out args.unclassified_out = None if "use_names" in args: use_names = True args.use_names = None tmp_filenames = [] LOG.info("Creating merged taxonomy\n") seqid2taxid_maps = [] for db in dbs: seqid2taxid_map = os.path.join(db, "seqid2taxid.map") if not os.path.exists(seqid2taxid_map): LOG.error( "Unable to find seqid2taxid.map for database {}\n".format(db) ) LOG.error( 
"seqid2taxid.map files are needed to create a merged taxonomy" " for mulit-database classification\n" ) seqid2taxid_maps.append(seqid2taxid_map) seqid2taxid_maps = sorted(seqid2taxid_maps) with tempfile.NamedTemporaryFile( prefix="k2_seqid2taxid", suffix=".map", delete=False) as out_file: seqid2taxid_map_pathname = out_file.name for seqid2taxid_map in seqid2taxid_maps: with open(seqid2taxid_map, "rb") as in_file: shutil.copyfileobj(in_file, out_file) taxonomy_pathname = os.path.join(dbs[0], "taxonomy") names_pathname = os.path.join(taxonomy_pathname, "names.dmp") nodes_pathname = os.path.join(taxonomy_pathname, "nodes.dmp") taxonomy_dll_pathname = find_kraken2_binary("libtax.so") taxonomy = Taxonomy(taxonomy_dll_pathname) taxonomy_pathname = tempfile.mktemp(prefix="k2_taxo", suffix=".k2d") taxonomy.generate_taxonomy( names_pathname, nodes_pathname, seqid2taxid_map_pathname, taxonomy_pathname ) LOG.info("Finished creating and loading merged taxonomy\n") for db in dbs: args.db = db pathname = tempfile.mktemp(prefix="k2_") tmp_filenames.append(pathname) args.output = pathname LOG.info( "Running classification job for database {}\n".format(db) ) classify(args) LOG.info( "Finished running classification job for database {}\n".format(db) ) LOG.info("Merging output files\n") out_filename = tempfile.mktemp(prefix="k2_") tmp_filenames_copy = tmp_filenames.copy() tmp_filenames.append(out_filename) final = False pool = concurrent.futures.ProcessPoolExecutor(max_workers=args.threads) progress = ProgressBar(len(tmp_filenames), 0) while True: LOG.debug("Merge progress: {}\r".format(progress.get_bar())) if len(tmp_filenames_copy) == 2: final = True filename1, filename2 = tmp_filenames_copy.pop(), tmp_filenames_copy.pop() # LOG.info("merging classification output\n") save_seq_names = final and (classified_out_filename is not None or unclassified_out_filename is not None) call_counters, total_seqs, classified_headers =\ merge_classification_output_parallel( pool, taxonomy_pathname, filename1, filename2, out_filename, use_names and final, args, save_seq_names, final ) progress.progress(1, relative=True) # LOG.info("finished merging classification output\n") LOG.debug("Merge progress: {}\r".format(progress.get_bar())) if len(tmp_filenames_copy) == 0: break tmp_filenames_copy.insert(0, out_filename) out_filename = filename1 pool.shutdown() LOG.info("Finished merging output files\n") if output: shutil.move(out_filename, os.path.abspath(output)) if out_filename in tmp_filenames: tmp_filenames.remove(out_filename) else: with open(out_filename, "r") as out_file: shutil.copyfileobj(out_file, sys.stdout) for tmp_filename in tmp_filenames: os.remove(tmp_filename) if classified_out_filename or unclassified_out_filename: LOG.info("Writing (un)classified sequences to file\n") write_fasta_sequences( args, args.filenames, classified_out_filename, unclassified_out_filename, classified_headers ) LOG.info("Finished writing (un)classified sequences to file\n") if report_filename is not None: LOG.info("Generating report file\n") taxonomy.load_taxonomy(taxonomy_pathname) report_kraken_style( report_filename, report_zeros, False, taxonomy, call_counters, total_seqs ) LOG.info("Finished generating report file\n") os.remove(taxonomy_pathname) os.remove(seqid2taxid_map_pathname) def classify(args): classify_bin = find_kraken2_binary("classify") database_path = find_database(args.db) if database_path is None: LOG.error("{:s} is not a valid database... 
exiting\n".format(args.db)) sys.exit(1) if "paired" in args and len(args.filenames) % 2 != 0: LOG.error("--paired requires an even number of file names\n") sys.exit(1) if args.confidence < 0 or args.confidence > 1: LOG.error( "--confidence, {:f}, must be between 0 and 1 inclusive\n".format( args.confidence ) ) sys.exit(1) argv = [ classify_bin, "-H", os.path.join(database_path, "hash.k2d"), "-t", os.path.join(database_path, "taxo.k2d"), "-o", os.path.join(database_path, "opts.k2d"), ] wrapper_args_to_binary_args(args, argv, get_binary_options(classify_bin)) if any([is_compressed(filename) for filename in args.filenames]): with tempfile.TemporaryDirectory() as temp_dir_name: fifo1_pathname = os.path.join(temp_dir_name, "fifo1") fifo2_pathname = None try: os.mkfifo(fifo1_pathname, 0o600) except OSError: LOG.error( "Unable to create FIFO for processing compressed files\n" ) sys.exit(1) if "-P" in argv: fifo2_pathname = os.path.join(temp_dir_name, "fifo2") try: os.mkfifo(fifo2_pathname, 0o600) except OSError: LOG.error( "Unable to create FIFO for processing compressed files\n" ) sys.exit(1) argv.extend([fifo1_pathname, fifo2_pathname]) else: argv.append(fifo1_pathname) if args.use_daemon: thread = threading.Thread( target=classify_using_daemon, args=(args, argv) ) else: thread = threading.Thread(target=subprocess.call, args=(argv,)) thread.start() if "-P" in argv: writer_thread1 = threading.Thread( target=write_to_fifo, args=(args.filenames[0::2], fifo1_pathname) ) writer_thread2 = threading.Thread( target=write_to_fifo, args=(args.filenames[1::2], fifo2_pathname) ) writer_thread1.start() writer_thread2.start() writer_thread1.join() writer_thread2.join() else: write_to_fifo(args.filenames, fifo1_pathname) thread.join() else: for i, filename in enumerate(args.filenames): args.filenames[i] = os.path.abspath(filename) argv.extend(args.filenames) if args.use_daemon: classify_using_daemon(args, argv) else: subprocess.call(argv) def inspect_db(args): database_pathname = find_database(args.db) if not database_pathname: LOG.error("{:s} database does not exist\n".format(args.db)) sys.exit(1) for database_file in ["taxo.k2d", "hash.k2d", "opts.k2d"]: if not os.path.isfile(os.path.join(database_pathname, database_file)): LOG.error("{:s} does not exist\n".format(database_file)) dump_table_bin = find_kraken2_binary("dump_table") argv = [ dump_table_bin, "-H", os.path.join(database_pathname, "hash.k2d"), "-t", os.path.join(database_pathname, "taxo.k2d"), "-o", os.path.join(database_pathname, "opts.k2d"), ] # dump_table does not save the table header to file. # This is a workaround helps enables us to capture # the entire output. 
    output_filename, args.output = args.output, None
    wrapper_args_to_binary_args(args, argv, get_binary_options(dump_table_bin))
    process = subprocess.Popen(argv, stdout=subprocess.PIPE)
    if output_filename == "-":
        shutil.copyfileobj(process.stdout, sys.stdout.buffer)
    else:
        with open(output_filename, "wb") as fout:
            shutil.copyfileobj(process.stdout, fout)
    process.wait()


def format_bytes(size):
    current_suffix = "B"
    for suffix in ["kB", "MB", "GB", "TB", "PB", "EB"]:
        if size >= 1024:
            current_suffix = suffix
            size /= 1024
        else:
            break
    return "{:.2f}{:s}".format(size, current_suffix)


def clean_up(filenames):
    LOG.info("Removing the following files: {}\n".format(filenames))
    # remove_files walks the directory tree and returns the summed size of
    # the deleted files so we can report how much space was freed.
    space_freed = format_bytes(remove_files(filenames))
    LOG.info("Cleaned up {} of space\n".format(space_freed))


def range_parser(text):
    if text == "all":
        text = ".."
    volumes = []
    regex = re.compile(r"(\d+)?(?:\.{2,3}|\-|:)(\d+)?")
    for volume in text.replace(" ", "").split(","):
        if volume.isdecimal():
            volumes.append(int(volume))
        else:
            match = regex.match(volume)
            if not match:
                raise argparse.ArgumentTypeError(text)
            start, end = match.group(1), match.group(2)
            if not start:
                start = 0
            if not end:
                end = 1000
            # Ranges are inclusive, hence the "+ 1".
            start, end = int(start), int(end) + 1
            expanded_range = list(range(start, end))
            volumes.extend(expanded_range)
    volumes = set(volumes)
    return volumes


def clean_db(args):
    if args.stop_daemon:
        message_daemon(b"STOP\n")
        LOG.info("Stopped background classifier process\n")
        cleanup_fifos()
    else:
        os.chdir(args.db)
        if args.pattern:
            clean_up(glob.glob(args.pattern, recursive=False))
        else:
            clean_up(
                [
                    "data",
                    "library",
                    "taxonomy",
                    "seqid2taxid.map",
                    "prelim_map.txt",
                ]
            )


def make_build_parser(subparsers):
    parser = subparsers.add_parser(
        "build",
        help="Build a database from a library\
        (requires a taxonomy, which can be downloaded\
        via the download-taxonomy subcommand, and at least one library,\
        which can be added via the download-library or\
        add-to-library subcommands).",
    )
    parser.add_argument(
        "--db",
        type=str,
        metavar="PATHNAME",
        required=True,
        help="Pathname to database folder where building will take place.",
    )
    group = parser.add_argument_group("special")
    mutex_group = group.add_mutually_exclusive_group()
    mutex_group.add_argument(
        "--standard",
        action="store_true",
        help="Make standard database which includes: archaea,\
        bacteria, human, plasmid, UniVec_Core, and viral."
    )
    mutex_group.add_argument(
        "--special",
        type=str,
        choices=["greengenes", "rdp", "silva", "gtdb"],
        help="Build special database. RDP is currently unavailable\
        as its URLs no longer work.",
    )
    group.add_argument(
        "--gtdb-files",
        type=str,
        nargs="+",
        help="A list of files, or a regex matching the files, needed to build\
        the special database."
) group.add_argument( "--gtdb-use-ncbi-taxonomy", action="store_true", help="Use NCBI tax IDs and taxonomy tree when building GTDB database" ) group.add_argument( "--gtdb-server", type=str, default=GTDB_SERVER, help="The GTDB server to use (default: {})".format(GTDB_SERVER) ) group.add_argument( "--no-masking", action="store_true", help="Avoid masking low-complexity sequences prior to\ building database.", ) group.add_argument( "--masker-threads", type=int, default=4, metavar="K2MASK_THREADS", help="Number of threads used by k2mask during masking\ process (default: 4)" ) parser.add_argument( "--kmer-len", type=int, metavar="INT", help="K-mer length in bp/aa" ) parser.add_argument( "--minimizer-len", type=int, metavar="INT", help="Minimizer length in bp/aa", ) parser.add_argument( "--minimizer-spaces", type=int, metavar="INT", help="Number of characters in minimizer that are\ ignored in comparisons", ) parser.add_argument( "--threads", type=int, metavar="INT", default=os.environ.get("KRAKEN2_NUM_THREADS") or 1, help="Number of threads", ) parser.add_argument( "--load-factor", type=float, metavar="FLOAT (0,1]", default=0.7, help="Proportion of the hash table to be populated (default: 0.7)", ) parser.add_argument( "--fast-build", action="store_true", help="Do not require database to be deterministically\ built when using multiple threads. This is faster, but\ does introduce variability in minimizer/LCA pairs.", ) parser.add_argument( "--max-db-size", # type=int, metavar="SIZE", help="Maximum number of bytes for Kraken 2 hash table;\ if the estimator determines more would normally be\ needed, the reference library will be downsampled to fit", ) parser.add_argument( "--skip-maps", action="store_true", help="Avoids downloading accession number to taxid maps", ) parser.add_argument( "--protein", action="store_true", help="Build a protein database for translated search", ) parser.add_argument( "--block-size", type=int, metavar="INT", default=16384, help="Read block size (default: 16384)", ) parser.add_argument( "--sub-block-size", type=int, metavar="INT", default=0, help="Read subblock size", ) parser.add_argument( "--minimum-bits-for-taxid", type=int, metavar="INT", default=0, help="Bit storage requested for taxid", ) parser.add_argument( "--log", type=str, metavar="FILENAME", default=None, help="Specify a log file (default: stderr)", ) def make_download_taxonomy_parser(subparsers): parser = subparsers.add_parser( "download-taxonomy", help="Download NCBI taxonomic information" ) parser.add_argument( "--db", type=str, metavar="PATHNAME", required=True, help="Pathname to Kraken2 database", ) # parser.add_argument( # "--source", # type=str, # choices=[ # "GTDB", # "NCBI" # ], # default="NCBI", # help="From which database should the files be downloaded" # ) parser.add_argument( "--protein", action="store_true", help="Files being added are for a protein database", ) parser.add_argument( "--skip-maps", action="store_true", help="Avoids downloading accession number to taxid maps", ) parser.add_argument( "--log", type=str, metavar="FILENAME", default=None, help="Specify a log filename (default: stderr)", ) def make_download_library_parser(subparsers): parser = subparsers.add_parser( "download-library", aliases=["download"], help="Download and build a special database" ) parser.register("type", "range", range_parser) parser.add_argument( "--db", type=str, metavar="PATHNAME", required=True, help="Pathname to Kraken2 database", ) parser.add_argument( "--library", "--taxid", "--project", "--accession", 
type=str, dest="library", required=True, # choices=[ # "archaea", # "bacteria", # "plasmid", # "plastid", # "viral", # "human", # "invertebrate", # "fungi", # "plant", # "protozoa", # "vertebrate_other" # "vertebrate_mammalian", # "mitochondrion", # "nr", # "nt", # "UniVec", # "UniVec_Core", # ], help="Name of library to download", ) parser.add_argument( "--assembly-source", type=str, required=False, choices=["refseq", "genbank", "all"], default="refseq", help="Download RefSeq (GCF_) or GenBank (GCA_) genome assemblies\ or both (default RefSeq)", ) parser.add_argument( "--resume", action="store_true", help="Resume fetching the files needed for a library, skipping files\ that have already been downloaded", ) parser.add_argument( "--assembly-levels", type=str, nargs="+", choices=["chromosome", "complete_genome", "scaffold", "contig"], default=["chromosome", "complete_genome"], help="Only return genome assemblies that have one of the specified\ assembly levels (default chromosome and complete genome)" ) parser.add_argument( "--has-annotation", action="store_true", help="Return only annotated genome assemblies (default: false)" ) parser.add_argument( "--blast-volumes", type="range", default="all", help="A comma separated list of the blast volume numbers to download.\ Ranges are also accepted in the forms start..end, start-end, start:end,\ ranges are inclusive (default: all volumes)" ) parser.add_argument( "--protein", action="store_true", help="Files being added are for a protein database", ) parser.add_argument( "--log", type=str, metavar="FILENAME", default=None, help="Specify a log filename (default: stderr)", ) parser.add_argument( "--threads", type=int, metavar="THREADS", default=1, help="The number of threads/processes k2 uses when downloading\ and processing library files.", ) masking_parser = parser.add_mutually_exclusive_group() masking_parser.add_argument( "--no-masking", action="store_true", help="Avoid asking low-complexity sequences prior to\ building; masking requires k2mask or segmasker to be\ installed", ) masking_parser.add_argument( "--masker-threads", type=int, default=4, metavar="K2MASK_THREADS", help="Number of threads used by k2mask during masking\ process (default: 4)" ) def make_add_to_library_parser(subparsers): parser = subparsers.add_parser( "add-to-library", help="Add file(s) to library" ) parser.add_argument( "--db", type=str, metavar="PATHNAME", required=True, help="Pathname to Kraken2 database", ) parser.add_argument( "--threads", type=int, metavar="THREADS", default=1, help="The number of threads/processes k2 uses when\ adding library files." ) parser.add_argument( "--file", "--files", type=str, nargs="+", required=True, dest="files", help="""Pathname or patterns of file(s) to be added to library. Supported pattern are as follows: ? - A question-mark is a pattern that shall match any character. * - An asterisk is a pattern that shall match multiple characters. [ - The open bracket shall introduce a pattern bracket expression. ** - will match any files and zero or more directories, subdirectories and symbolic links to directories. """, ) parser.add_argument( "--protein", action="store_true", help="Files being added are for a protein database", ) parser.add_argument( "--log", type=str, metavar="FILENAME", default=None, help="Specify a log filename (default: stderr)", ) # parser.add_argument( # "--skip-md5", # action="store_true", # help="K2 will by default perform an MD5 check to determine whether\ # a file has already been added. 
    #     This option will allow the user to skip this and instead
    #     simply compare filenames."
    # )
    masking_parser = parser.add_mutually_exclusive_group()
    masking_parser.add_argument(
        "--no-masking",
        action="store_true",
        help="Avoid masking low-complexity sequences prior to\
        building; masking requires k2mask or segmasker to be\
        installed",
    )
    masking_parser.add_argument(
        "--masker-threads",
        type=int,
        metavar="K2MASK_THREADS",
        default=4,
        help="Number of threads used by k2mask during masking process\
        (default: 4)"
    )


def make_classify_parser(subparsers):
    parser = subparsers.add_parser(
        "classify", help="Classify a set of sequences"
    )
    parser.add_argument(
        "--db",
        type=lambda x: x.split(","),
        metavar="PATHNAME",
        # nargs="+",
        required=True,
        help="Pathname to Kraken2 database(s).\
        Multiple databases are specified as a\
        comma-separated list with no spaces.",
    )
    parser.add_argument(
        "--threads",
        type=int,
        metavar="INT",
        default=os.environ.get("KRAKEN2_NUM_THREADS") or 1,
        help="Number of threads",
    )
    parser.add_argument(
        "--use-daemon",
        action="store_true",
        help="Spawn a background process that keeps any loaded indexes\
        in memory. Subsequent invocations of classify with this option will\
        skip the index loading process and immediately start classifying\
        reads. If a new index is specified, that index will also be persisted.\
        Use k2 clean --stop-daemon to stop the background process."
    )
    parser.add_argument(
        "--quick",
        action="store_true",
        default=argparse.SUPPRESS,
        help="Quick operation (use first hit or hits)",
    )
    parser.add_argument(
        "--unclassified-out",
        type=str,
        default=argparse.SUPPRESS,
        metavar="FILENAME",
        help="Print unclassified sequences to filename",
    )
    parser.add_argument(
        "--classified-out",
        type=str,
        metavar="FILENAME",
        default=argparse.SUPPRESS,
        help="Print classified sequences to filename",
    )
    parser.add_argument(
        "--output",
        type=str,
        metavar="FILENAME",
        default=argparse.SUPPRESS,
        help='Print output to file (default: stdout); "-" will \
        suppress normal output',
    )
    parser.add_argument(
        "--confidence",
        type=float,
        default=0.0,
        help="Confidence score threshold (default: 0.0); must be in [0,1]",
    )
    parser.add_argument(
        "--minimum-base-quality",
        type=int,
        metavar="INT",
        default=0,
        help="Minimum base quality used in classification",
    )
    parser.add_argument(
        "--report",
        type=str,
        default=argparse.SUPPRESS,
        help="Print a report with aggregate counts/clade to file",
    )
    parser.add_argument(
        "--use-mpa-style",
        action="store_true",
        default=argparse.SUPPRESS,
        help="With --report, format report output like Kraken 1's\
        kraken-mpa-report",
    )
    parser.add_argument(
        "--report-zero-counts",
        action="store_true",
        default=argparse.SUPPRESS,
        help="With --report, report counts for ALL taxa, even if\
        counts are zero",
    )
    parser.add_argument(
        "--report-minimizer-data",
        action="store_true",
        default=argparse.SUPPRESS,
        help="With --report, report minimizer and distinct minimizer\
        count information in addition to normal Kraken report",
    )
    parser.add_argument(
        "--memory-mapping",
        action="store_true",
        default=argparse.SUPPRESS,
        help="Avoids loading entire database into RAM",
    )
    paired_group = parser.add_mutually_exclusive_group()
    paired_group.add_argument(
        "--paired",
        action="store_true",
        default=argparse.SUPPRESS,
        help="The filenames provided have paired-end reads",
    )
    paired_group.add_argument(
        "--interleaved",
        action="store_true",
        default=argparse.SUPPRESS,
        help="The filenames provided contain interleaved paired-end reads",
    )
    parser.add_argument(
        "--use-names",
        action="store_true",
        default=argparse.SUPPRESS,
        help="Print scientific names instead of just \
taxids", ) parser.add_argument( "--minimum-hit-groups", type=int, metavar="INT", default=2, help="Minimum number of hit groups (overlapping k-mers\ sharing the same minimizer) needed to make a call\ (default 2)", ) parser.add_argument( "--log", type=str, metavar="FILENAME", default=None, help="Specify a log filename (default: stderr)", ) parser.add_argument( "filenames", nargs="*", type=str, help="Filenames to be classified, supports bz2, gzip, and xz", ) def make_inspect_parser(subparsers): parser = subparsers.add_parser("inspect", help="Inspect Kraken 2 database") parser.add_argument( "--db", type=str, metavar="PATHNAME", required=True, help="Pathname to Kraken2 database", ) # parser.add_argument( # "--threads", # type=int, # default=os.environ.get("KRAKEN2_NUM_THREADS") or 1, # help="Number of threads", # ) parser.add_argument( "--skip-counts", action="store_true", help="Only print database summary statistics", ) parser.add_argument( "--use-mpa-style", action="store_true", help="Format output like Kraken 1's kraken-mpa-report", ) parser.add_argument( "--report-zero-counts", action="store_true", help="Report counts for ALL taxa, even if counts are zero", ) parser.add_argument( "--log", type=str, metavar="FILENAME", default=None, help="Specify a log filename (default: stderr)", ) parser.add_argument( "--output", "--out", type=str, metavar="FILENAME", default="-", help="Write inspect output to FILENAME (default: stdout)" ) parser.add_argument( "--memory-mapping", action="store_true", default=argparse.SUPPRESS, help="Avoids loading entire database into RAM", ) def make_clean_parser(subparsers): parser = subparsers.add_parser( "clean", help="Removes unwanted files from database" ) actions = parser.add_argument_group( "required", "Arguments required by the cleaner" ) mg = actions.add_mutually_exclusive_group(required=True) mg.add_argument( "--stop-daemon", action="store_true", help="Stop a running background process", ) mg.add_argument( "--db", type=str, metavar="PATHNAME", # required=True, help="Pathname to Kraken2 database", ) options = parser.add_argument_group( "options", "options for cleaning temporary files" ) options.add_argument( "--log", type=str, metavar="FILENAME", default=None, help="Specify a log filename (default: stderr)", ) options.add_argument( "--pattern", type=str, metavar="SHELL_REGEX", default=None, help="""Files that match this regular expression will be deleted. ? - A question-mark is a pattern that shall match any character. * - An asterisk is a pattern that shall match multiple characters. [ - The open bracket shall introduce a pattern bracket expression. ** - will match any files and zero or more directories, subdirectories and symbolic links to directories. 
""" ) class HelpAction(argparse._HelpAction): def __call__(self, parser, namespace, values, option_string=None): parser.print_help() subparsers = None for action in parser._actions: if "choices" in dir(action) and action.choices: subparsers = action.choices if not subparsers: sys.exit(0) for action, arg_parser in subparsers.items(): sys.stderr.write("\n\n" + action + "\n" + "-" * len(action) + "\n") arg_parser.print_help() sys.exit(0) def make_cmdline_parser(): parser = argparse.ArgumentParser("k2", add_help=False) parser.add_argument("-h", "--help", action=HelpAction) parser.add_argument( "-v", "--version", action="version", version=SCRIPT_VERSION ) subparsers = parser.add_subparsers() make_add_to_library_parser(subparsers) make_download_library_parser(subparsers) make_download_taxonomy_parser(subparsers) make_build_parser(subparsers) make_classify_parser(subparsers) make_inspect_parser(subparsers) make_clean_parser(subparsers) return parser class Logger: def __init__(self, filename): self.queue = multiprocessing.Manager().Queue(-1) logging.StreamHandler.terminator = "" self.logger = logging.getLogger("kraken2") if filename: self.logger.setLevel(logging.INFO) handler = logging.FileHandler(filename) formatter = logging.Formatter( "[%(levelname)s - %(asctime)s]: %(message)s" ) handler.setFormatter(formatter) self.logger.addHandler(handler) else: self.logger.setLevel(logging.DEBUG) handler = logging.StreamHandler() formatter = logging.Formatter( "[%(levelname)s - %(asctime)s]: %(message)s" ) handler.setFormatter(formatter) self.logger.addHandler(handler) self.thread = threading.Thread( target=Logger.process_thread, args=(self,), daemon=True ) self.thread.start() def debug(self, log): self.logger.debug(log) def info(self, log): self.logger.info(log) def warning(self, log): self.logger.warning(log) def error(self, log): self.logger.error(log) def process_thread(self): import queue while True: try: record = self.queue.get() if record is None: break except queue.Empty: continue except Exception: break self.logger.handle(record) def stop_thread(self): self.queue.put(None) # self.thread.join() def get_queue(self): return self.queue def get_level(self): return self.logger.level def setup_queue_logger(queue, level): queue_handler = logging.handlers.QueueHandler(queue) logger = logging.getLogger() logger.setLevel(level) logger.addHandler(queue_handler) return logger def __del__(self): try: self.stop_thread() except Exception: pass def k2_main(): global SCRIPT_PATHNAME global LOG SCRIPT_PATHNAME = os.path.realpath(inspect.getsourcefile(k2_main)) parser = make_cmdline_parser() if len(sys.argv) == 1: parser.print_help() sys.exit(1) args = parser.parse_args(sys.argv[1:]) LOG = Logger(args.log) task = sys.argv[1] # if task not in ["classify", "inspect"]: # args.db = os.path.abspath(args.db) if isinstance(args.db, list): args.db = list(map(os.path.abspath, args.db)) elif isinstance(args.db, str): args.db = os.path.abspath(args.db) if task == "download-taxonomy": download_taxonomy(args) elif task == "classify": if len(args.db) > 1: classify_multi_dbs(args) else: args.db = args.db[0] classify(args) elif task == "download-library": download_genomic_library(args) elif task == "add-to-library": add_to_library(args) elif task == "inspect": inspect_db(args) elif task == "clean": clean_db(args) elif task == "build": # Protein defaults default_aa_minimizer_length = 12 default_aa_kmer_length = 15 default_aa_minimizer_spaces = 0 # Nucleotide defaults default_nt_minimizer_length = 31 default_nt_kmer_length = 35 
        default_nt_minimizer_spaces = 7
        if args.sub_block_size == 0:
            args.sub_block_size = math.ceil(args.block_size / args.threads)
        if not args.kmer_len:
            args.kmer_len = (
                default_aa_kmer_length
                if args.protein
                else default_nt_kmer_length
            )
        if not args.minimizer_len:
            args.minimizer_len = (
                default_aa_minimizer_length
                if args.protein
                else default_nt_minimizer_length
            )
        if not args.minimizer_spaces:
            args.minimizer_spaces = (
                default_aa_minimizer_spaces
                if args.protein
                else default_nt_minimizer_spaces
            )
        if args.minimizer_len > args.kmer_len:
            LOG.error(
                "Minimizer length ({}) must not be greater than k-mer "
                "length ({})\n".format(args.minimizer_len, args.kmer_len)
            )
            sys.exit(1)
        if args.load_factor <= 0 or args.load_factor > 1:
            LOG.error(
                "Load factor must be greater than 0 but no more than 1\n"
            )
            sys.exit(1)
        if args.minimizer_len <= 0 or args.minimizer_len > 31:
            LOG.error(
                "Minimizer length must be a positive integer "
                "and cannot exceed 31.\n"
            )
            sys.exit(1)
        if args.standard:
            build_standard_database(args)
        elif args.special:
            if args.special == "greengenes":
                build_16S_gg(args)
            elif args.special == "silva":
                build_16S_silva(args)
            elif args.special == "gtdb":
                if not args.gtdb_files:
                    LOG.error(
                        "Please specify a list of files or a pattern "
                        "matching the files needed to build a GTDB "
                        "database.\n"
                    )
                    sys.exit(1)
                build_gtdb_database(args)
            else:
                # build_16S_rdp(args)
                LOG.error("The RDP database is no longer supported.\n")
                sys.exit(1)
        else:
            if args.no_masking:
                LOG.warning(
                    "--no-masking only affects the `--standard` and "
                    "`--special` flags. Its effect will be ignored.\n"
                )
            build_kraken2_db(args)


if __name__ == "__main__":
    try:
        k2_main()
    except KeyboardInterrupt:
        pass
    except Exception:
        # LOG may still be None if the failure happened before logging was
        # configured in k2_main.
        if LOG is not None:
            LOG.stop_thread()
            LOG.error(traceback.format_exc())
        else:
            traceback.print_exc()
        sys.exit(1)
    # else:
    #     LOG.stop_thread()
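# Illustrative invocations of the subcommands defined above (nothing here is
# executed; database paths and read filenames are placeholders):
#
#   k2 download-taxonomy --db ./k2_db
#   k2 download-library --db ./k2_db --library bacteria --threads 8
#   k2 build --db ./k2_db --threads 8
#   k2 classify --db ./k2_db --threads 8 --report report.txt \
#       --paired reads_1.fq reads_2.fq
#   k2 classify --db ./db_a,./db_b --output out.tsv reads.fq
#   k2 inspect --db ./k2_db --skip-counts
#   k2 clean --db ./k2_db
#   k2 clean --stop-daemon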