import copy
import dataclasses
import json
import os
import random
import tarfile
import time

import numpy as np
import requests
from tqdm import tqdm

# NOTE(review): the module path of the `alphafold.data` imports below was
# missing in the reviewed source (`from import ...`); it is reconstructed from
# the upstream AlphaFold package layout - confirm against the installed version.
from alphafold.common import residue_constants
from alphafold.data import feature_processing, msa_pairing, msa_identifiers
from alphafold.data import parsers
from alphafold.data import pipeline as alpha_pipeline
from alphafold.data import pipeline_multimer


# Copied from AlphaFold - should recheck where this comes from
def process_multimer_features(features_for_chain):
    """Merge per-chain monomer features into a single multimer example.

    Each chain is converted to multimer format, assembly features are added,
    the ``*_all_seq`` features are padded with ``msa_pairing.pad_features``
    (rows are NOT re-ordered/paired here), and the chains are then cropped,
    merged and finalised with the AlphaFold ``feature_processing`` helpers.

    Args:
        features_for_chain: Mapping ``chain_id -> monomer feature dict``.

    Returns:
        The merged feature dict, with the MSA padded to at least 512 rows.

    Raises:
        ValueError: If ``features_for_chain`` is empty.
    """
    if not features_for_chain:
        raise ValueError('At least one chain must be provided.')

    all_chain_features = {
        chain_id: pipeline_multimer.convert_monomer_features(
            chain_features, chain_id
        )
        for chain_id, chain_features in features_for_chain.items()
    }
    all_chain_features = pipeline_multimer.add_assembly_features(all_chain_features)
    # Called for its in-place effect on the chain dicts (return value unused).
    feature_processing.process_unmerged_features(all_chain_features)

    chains = list(all_chain_features.values())
    pair_msa_sequences = not feature_processing._is_homomer_or_monomer(chains)

    # The set of "*_all_seq" features is taken from the first chain and
    # applied to every chain, exactly as the upstream pairing code does.
    all_seq_feature_names = [
        name for name in chains[0] if name.endswith("_all_seq")
    ]

    updated_chains = []
    for chain in chains:
        new_chain = {k: v for k, v in chain.items() if "_all_seq" not in k}
        for feature_name in all_seq_feature_names:
            new_chain[feature_name] = msa_pairing.pad_features(
                chain[feature_name], feature_name
            )
        # Row count of the *unpadded* all-seq MSA. Guarded so that chains
        # without all-seq features do not fail with a KeyError.
        if "msa_all_seq" in chain:
            new_chain["num_alignments_all_seq"] = np.asarray(
                len(chain["msa_all_seq"])
            )
        updated_chains.append(new_chain)

    np_chains_list = feature_processing.crop_chains(
        updated_chains,
        msa_crop_size=feature_processing.MSA_CROP_SIZE,
        pair_msa_sequences=pair_msa_sequences,
        max_templates=feature_processing.MAX_TEMPLATES,
    )
    np_example = msa_pairing.merge_chain_features(
        np_chains_list=np_chains_list,
        pair_msa_sequences=pair_msa_sequences,
        max_templates=feature_processing.MAX_TEMPLATES,
    )
    np_example = feature_processing.process_final(np_example)
    # Pad MSA to avoid a zero-sized extra_msa.
    np_example = pipeline_multimer.pad_msa(np_example, min_num_seq=512)
    return np_example
# Copied from AlphaFold, adapted so that duplicate rows in the MSA are not
# deleted: removing duplicates caused problems when pairing MSAs.
def my_make_msa_features_keep_duplicates(msas):
    """Build MSA features from parsed MSAs, keeping duplicate sequences.

    Args:
        msas: Non-empty sequence of MSA objects exposing ``sequences``,
            ``deletion_matrix`` and ``descriptions``.

    Returns:
        Dict with ``deletion_matrix_int``, ``msa``, ``num_alignments`` and
        ``msa_species_identifiers`` numpy arrays.

    Raises:
        ValueError: If ``msas`` is empty or any MSA in it is falsy.
    """
    if not msas:
        raise ValueError('At least one MSA must be provided.')

    int_msa = []
    deletion_matrix = []
    species_ids = []
    for msa_index, msa in enumerate(msas):
        if not msa:
            raise ValueError(f'MSA {msa_index} must contain at least one sequence.')
        for sequence_index, sequence in enumerate(msa.sequences):
            int_msa.append(
                [residue_constants.HHBLITS_AA_TO_ID[res] for res in sequence])
            deletion_matrix.append(msa.deletion_matrix[sequence_index])
            identifiers = msa_identifiers.get_identifiers(
                msa.descriptions[sequence_index])
            species_ids.append(identifiers.species_id.encode('utf-8'))

    # The first sequence of the first MSA defines the number of residues.
    num_res = len(msas[0].sequences[0])
    num_alignments = len(int_msa)
    return {
        'deletion_matrix_int': np.array(deletion_matrix, dtype=np.int32),
        'msa': np.array(int_msa, dtype=np.int32),
        'num_alignments': np.array([num_alignments] * num_res, dtype=np.int32),
        'msa_species_identifiers': np.array(species_ids, dtype=np.object_),
    }


# Copied from the AlphaFold data pipeline package: the process function is
# modified to use a seq_to_features_cache across runs (multiple permutations),
# instead of within one run.
class Cached_DataPipeline(pipeline_multimer.DataPipeline):
    """Multimer data pipeline whose per-sequence feature cache outlives a run."""

    # Copied from the AlphaFold repo, adapted caching.
    def __init__(self, monomer_data_pipeline, jackhmmer_binary_path,
                 uniprot_database_path, use_precomputed_msas,
                 seq_to_features_cache):
        """Initialise the base pipeline and attach the shared cache.

        Args:
            monomer_data_pipeline: Forwarded to the base class.
            jackhmmer_binary_path: Forwarded to the base class.
            uniprot_database_path: Forwarded to the base class.
            use_precomputed_msas: Forwarded to the base class.
            seq_to_features_cache: Mutable mapping ``sequence -> chain
                features``; read and updated by :meth:`process`.
        """
        super().__init__(
            monomer_data_pipeline=monomer_data_pipeline,
            jackhmmer_binary_path=jackhmmer_binary_path,
            uniprot_database_path=uniprot_database_path,
            use_precomputed_msas=use_precomputed_msas)
        self.seq_to_features_cache = seq_to_features_cache

    def process(self, input_fasta_path: str,
                msa_output_dir: str) -> alpha_pipeline.FeatureDict:
        """Runs alignment tools on the input sequences and creates features.

        Writes ``chain_id_map.json`` into ``msa_output_dir`` and returns the
        merged multimer feature dict, padded to at least 512 MSA rows.
        """
        with open(input_fasta_path) as f:
            input_fasta_str = f.read()
        input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)

        chain_id_map = pipeline_multimer._make_chain_id_map(
            sequences=input_seqs, descriptions=input_descs)
        chain_id_map_path = os.path.join(msa_output_dir, 'chain_id_map.json')
        with open(chain_id_map_path, 'w') as f:
            chain_id_map_dict = {
                chain_id: dataclasses.asdict(fasta_chain)
                for chain_id, fasta_chain in chain_id_map.items()}
            json.dump(chain_id_map_dict, f, indent=4, sort_keys=True)

        all_chain_features = {}
        is_homomer_or_monomer = len(set(input_seqs)) == 1
        for chain_id, fasta_chain in chain_id_map.items():
            # NOTE(review): the cache is keyed on the sequence only, so a hit
            # reuses features computed with whatever is_homomer_or_monomer
            # value the populating run had - confirm this is intended.
            if fasta_chain.sequence in self.seq_to_features_cache:
                # Deep copy so downstream in-place edits never touch the cache.
                all_chain_features[chain_id] = copy.deepcopy(
                    self.seq_to_features_cache[fasta_chain.sequence])
                continue
            chain_features = self._process_single_chain(
                chain_id=chain_id,
                sequence=fasta_chain.sequence,
                description=fasta_chain.description,
                msa_output_dir=msa_output_dir,
                is_homomer_or_monomer=is_homomer_or_monomer)
            chain_features = pipeline_multimer.convert_monomer_features(
                chain_features, chain_id=chain_id)
            all_chain_features[chain_id] = chain_features
            self.seq_to_features_cache[fasta_chain.sequence] = copy.deepcopy(
                chain_features)

        all_chain_features = pipeline_multimer.add_assembly_features(
            all_chain_features)
        np_example = feature_processing.pair_and_merge(
            all_chain_features=all_chain_features)

        # Pad MSA to avoid zero-sized extra_msa.
        np_example = pad_msa(np_example, 512)
        return np_example


# Copied from the AlphaFold repo, to use in the pipeline class.
def pad_msa(np_example, min_num_seq):
    """Return a shallow copy of ``np_example`` with at least ``min_num_seq`` MSA rows.

    ``msa``, ``deletion_matrix``, ``bert_mask`` and ``msa_mask`` are padded
    along their first axis; ``cluster_bias_mask`` is padded as a 1-D array.
    ``np.pad`` is called with its default mode, so the padding is zeros.
    """
    np_example = dict(np_example)
    num_seq = np_example['msa'].shape[0]
    if num_seq < min_num_seq:
        missing = min_num_seq - num_seq
        for feat in ('msa', 'deletion_matrix', 'bert_mask', 'msa_mask'):
            np_example[feat] = np.pad(np_example[feat], ((0, missing), (0, 0)))
        np_example['cluster_bias_mask'] = np.pad(
            np_example['cluster_bias_mask'], ((0, missing),))
    return np_example


# Adapted from ColabFold
##########################################
# call mmseqs2
##########################################

TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'


def run_mmseqs2(x, prefix, use_env=True, use_filter=True, use_templates=False,
                filter=None, use_pairing=False, host_url=""):
    """Query an MMseqs2 API server for MSAs (and optionally templates).

    Results are stored under the directory ``f"{prefix}_{mode}"``; if
    ``out.tar.gz`` already exists there, no job is submitted.

    Args:
        x: A single sequence string or a list of sequence strings.
        prefix: Path prefix of the working directory.
        use_env: Also use the environmental database a3m (forced to False
            when ``use_pairing`` is set).
        use_filter: Selects between the filtered and "nofilter" modes.
        use_templates: Also fetch templates (forced to False when
            ``use_pairing`` is set).
        filter: Deprecated alias; when not None it overrides ``use_filter``.
        use_pairing: Submit to the pairing endpoint and read ``pair.a3m``.
        host_url: Base URL of the API server.

    Returns:
        A list with one a3m string per input sequence (in input order). When
        ``use_templates`` is true, a tuple ``(a3m_lines, template_paths)``
        where ``template_paths`` has one directory path or None per input.

    Raises:
        Exception: If the server reports ERROR or MAINTENANCE.
    """
    # Local import: keeps this block self-contained without touching the
    # file-level import section.
    import shlex

    submission_endpoint = "ticket/pair" if use_pairing else "ticket/msa"

    def submit(seqs, mode, N=101):
        # FASTA-like query; headers are consecutive integers starting at N.
        query = "".join(f">{n}\n{seq}\n" for n, seq in enumerate(seqs, start=N))
        res = requests.post(f'{host_url}/{submission_endpoint}',
                            data={'q': query, 'mode': mode})
        try:
            out = res.json()
        except ValueError:
            print(f"Server didn't reply with json: {res.text}")
            out = {"status": "ERROR"}
        return out

    def status(ID):
        res = requests.get(f'{host_url}/ticket/{ID}')
        try:
            out = res.json()
        except ValueError:
            print(f"Server didn't reply with json: {res.text}")
            out = {"status": "ERROR"}
        return out

    def download(ID, path):
        res = requests.get(f'{host_url}/result/download/{ID}')
        with open(path, "wb") as out:
            out.write(res.content)

    # process input x
    seqs = [x] if isinstance(x, str) else x

    # compatibility to old option
    if filter is not None:
        use_filter = filter

    # setup mode
    if use_filter:
        mode = "env" if use_env else "all"
    else:
        mode = "env-nofilter" if use_env else "nofilter"

    if use_pairing:
        mode = ""
        use_templates = False
        use_env = False

    # define path
    path = f"{prefix}_{mode}"
    if not os.path.isdir(path):
        os.mkdir(path)

    # call mmseqs2 api
    tar_gz_file = f'{path}/out.tar.gz'
    N, REDO = 101, True

    # Deduplicate while preserving first-seen order, and remember which
    # result id (N + position among unique sequences) each input maps to.
    seqs_unique = list(dict.fromkeys(seqs))
    id_of_seq = {seq: N + i for i, seq in enumerate(seqs_unique)}
    Ms = [id_of_seq[seq] for seq in seqs]

    # lets do it!
    if not os.path.isfile(tar_gz_file):
        TIME_ESTIMATE = 150 * len(seqs_unique)
        with tqdm(total=TIME_ESTIMATE, bar_format=TQDM_BAR_FORMAT) as pbar:
            while REDO:
                pbar.set_description("SUBMIT")

                # Resubmit job until it goes through
                out = submit(seqs_unique, mode, N)
                while out["status"] in ["UNKNOWN", "RATELIMIT"]:
                    sleep_time = 5 + random.randint(0, 5)
                    print(f"Sleeping for {sleep_time}s. Reason: {out['status']}")
                    # resubmit
                    time.sleep(sleep_time)
                    out = submit(seqs_unique, mode, N)

                if out["status"] == "ERROR":
                    raise Exception('MMseqs2 API is giving errors. Please confirm your input is a valid protein sequence. If error persists, please try again an hour later.')

                if out["status"] == "MAINTENANCE":
                    raise Exception('MMseqs2 API is undergoing maintenance. Please try again in a few minutes.')

                # wait for job to finish
                ID, TIME = out["id"], 0
                pbar.set_description(out["status"])
                while out["status"] in ["UNKNOWN", "RUNNING", "PENDING"]:
                    t = 5 + random.randint(0, 5)
                    print(f"Sleeping for {t}s. Reason: {out['status']}")
                    time.sleep(t)
                    out = status(ID)
                    pbar.set_description(out["status"])
                    if out["status"] == "RUNNING":
                        TIME += t
                        pbar.update(n=t)

                if out["status"] == "COMPLETE":
                    # Fill the progress bar up to the estimate.
                    if TIME < TIME_ESTIMATE:
                        pbar.update(n=(TIME_ESTIMATE - TIME))
                    REDO = False

                if out["status"] == "ERROR":
                    REDO = False
                    raise Exception('MMseqs2 API is giving errors. Please confirm your input is a valid protein sequence. If error persists, please try again an hour later.')

            # Download results
            download(ID, tar_gz_file)

    # prep list of a3m files
    if use_pairing:
        a3m_files = [f"{path}/pair.a3m"]
    else:
        a3m_files = [f"{path}/uniref.a3m"]
        if use_env:
            a3m_files.append(f"{path}/bfd.mgnify30.metaeuk30.smag30.a3m")

    # extract a3m files
    if any(not os.path.isfile(a3m_file) for a3m_file in a3m_files):
        # NOTE(review): the archive comes from a remote server; consider an
        # extraction filter once the minimum supported Python allows it.
        with tarfile.open(tar_gz_file) as tar_gz:
            tar_gz.extractall(path)

    # templates
    if use_templates:
        templates = {}
        with open(f"{path}/pdb70.m8", "r") as m8_file:
            for line in m8_file:
                p = line.rstrip().split()
                if not p:
                    continue  # tolerate blank lines
                # Column 0 is the query id, column 1 the template id.
                templates.setdefault(int(p[0]), []).append(p[1])

        template_paths = {}
        for k, TMPL in templates.items():
            TMPL_PATH = f"{prefix}_{mode}/templates_{k}"
            if not os.path.isdir(TMPL_PATH):
                os.mkdir(TMPL_PATH)
                TMPL_LINE = ",".join(TMPL[:20])
                # Quote every interpolated value: the ids come from a
                # server-provided file and the paths from the caller.
                url = shlex.quote(f"{host_url}/template/{TMPL_LINE}")
                dest = shlex.quote(f"{TMPL_PATH}/")
                ffindex_src = shlex.quote(f"{TMPL_PATH}/pdb70_a3m.ffindex")
                ffindex_dst = shlex.quote(f"{TMPL_PATH}/pdb70_cs219.ffindex")
                ffdata = shlex.quote(f"{TMPL_PATH}/pdb70_cs219.ffdata")
                os.system(f"curl -s -L {url} | tar xzf - -C {dest}")
                os.system(f"cp {ffindex_src} {ffindex_dst}")
                os.system(f"touch {ffdata}")
            template_paths[k] = TMPL_PATH

    # gather a3m lines
    a3m_lines = {}
    for a3m_file in a3m_files:
        update_M, M = True, None
        with open(a3m_file, "r") as a3m_handle:
            for line in a3m_handle:
                # A NUL byte separates the per-query alignments; the next
                # ">" header after it carries the new query id.
                if "\x00" in line:
                    line = line.replace("\x00", "")
                    update_M = True
                if line.startswith(">") and update_M:
                    M = int(line[1:].rstrip())
                    update_M = False
                    if M not in a3m_lines:
                        a3m_lines[M] = []
                a3m_lines[M].append(line)

    # return results, re-expanded to the original (possibly duplicated) order
    a3m_lines = ["".join(a3m_lines[n]) for n in Ms]

    if use_templates:
        # None marks inputs for which no templates were found.
        template_paths = [template_paths.get(n) for n in Ms]

    return (a3m_lines, template_paths) if use_templates else a3m_lines