{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Machine learning - Features extraction\n", "\n", "Demo to create a feature vector for protein fold classification. \n", "In this demo we try to classify a protein chain as either an all alpha or all beta protein based on protein sequence. We use n-grams and a Word2Vec representation of the protein sequence as a feature vector.\n", "\n", "[Word2Vec model](https://spark.apache.org/docs/latest/mllib-feature-extraction.html#word2vec)\n", "\n", "[Word2Vec example](https://spark.apache.org/docs/latest/ml-features.html#word2vec)\n", "\n", "## Imports" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from pyspark.sql import SparkSession\n", "from pyspark.sql.functions import *\n", "from pyspark.sql.types import *\n", "from mmtfPyspark.io import mmtfReader\n", "from mmtfPyspark.webfilters import Pisces\n", "from mmtfPyspark.filters import ContainsLProteinChain\n", "from mmtfPyspark.mappers import StructureToPolymerChains\n", "from mmtfPyspark.datasets import secondaryStructureExtractor\n", "from mmtfPyspark.ml import ProteinSequenceEncoder" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Configure Spark Context" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "spark = SparkSession.builder.appName(\"1-Features\").getOrCreate()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Read MMTF File and get a set of L-protein chains" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "pdb = mmtfReader.read_sequence_file('../resources/mmtf_reduced_sample/') \\\n", " .flatMap(StructureToPolymerChains()) \\\n", " .filter(ContainsLProteinChain())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Get secondary structure content" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "data = 
secondaryStructureExtractor.get_dataset(pdb)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+----------------+--------------------+----------+----------+----------+--------------------+--------------------+\n", "|structureChainId| sequence| alpha| beta| coil| dsspQ8Code| dsspQ3Code|\n", "+----------------+--------------------+----------+----------+----------+--------------------+--------------------+\n", "| 4WMY.A|TDWSHPQFEKSTDEANT...|0.19081272|0.26855123|0.54063606|XXXXXXXXXXXXXXXXX...|XXXXXXXXXXXXXXXXX...|\n", "| 4WMY.B|TDWSHPQFEKSTDEANT...|0.17081851|0.26334518| 0.5658363|XXXXXXXXXXXXXXXXX...|XXXXXXXXXXXXXXXXX...|\n", "| 4WN5.A|GSHMGRGAFLSRHSLDM...| 0.2962963|0.37962964|0.32407406|XXCCCCCCEEEEECTTC...|XXCCCCCCEEEEECCCC...|\n", "| 4WN5.B|GSHMGRGAFLSRHSLDM...|0.33333334|0.37142858| 0.2952381|XXXXXCCCEEEEECTTC...|XXXXXCCCEEEEECCCC...|\n", "| 4WND.A|GPGSMEASCLELALEGE...| 0.8358663| 0.0|0.16413374|XXXXCCSCHHHHHHHHH...|XXXXCCCCHHHHHHHHH...|\n", "+----------------+--------------------+----------+----------+----------+--------------------+--------------------+\n", "only showing top 5 rows\n", "\n" ] } ], "source": [ "data.show(5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Define add_protein_fold_type function" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "def add_protein_fold_type(data, minThreshold, maxThreshold):\n", " '''\n", " Adds a column \"foldType\" that assigns each chain to one of four secondary structure classes:\n", " \"alpha\", \"beta\", \"alpha+beta\", or \"other\", based on the fraction of alpha/beta content.\n", "\n", " The simplified syntax used in this method relies on two imports:\n", " from pyspark.sql.functions import when\n", " from pyspark.sql.functions import col\n", "\n", " Args:\n", " data (Dataset): input dataset with alpha, beta composition\n", " minThreshold (float): below this threshold, the 
secondary structure is ignored\n", " maxThreshold (float): above this threshold, the secondary structure is counted as present\n", " '''\n", "\n", " return data.withColumn(\"foldType\", \\\n", " when((col(\"alpha\") > maxThreshold) & (col(\"beta\") < minThreshold), \"alpha\"). \\\n", " when((col(\"beta\") > maxThreshold) & (col(\"alpha\") < minThreshold), \"beta\"). \\\n", " when((col(\"alpha\") > maxThreshold) & (col(\"beta\") > maxThreshold), \"alpha+beta\"). \\\n", " otherwise(\"other\")\\\n", " )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Classify chains by secondary structure type" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "data = add_protein_fold_type(data, minThreshold=0.05, maxThreshold=0.15)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+----------------+--------------------+-----------+-----------+----------+--------------------+--------------------+----------+\n", "|structureChainId| sequence| alpha| beta| coil| dsspQ8Code| dsspQ3Code| foldType|\n", "+----------------+--------------------+-----------+-----------+----------+--------------------+--------------------+----------+\n", "| 4WMY.A|TDWSHPQFEKSTDEANT...| 0.19081272| 0.26855123|0.54063606|XXXXXXXXXXXXXXXXX...|XXXXXXXXXXXXXXXXX...|alpha+beta|\n", "| 4WMY.B|TDWSHPQFEKSTDEANT...| 0.17081851| 0.26334518| 0.5658363|XXXXXXXXXXXXXXXXX...|XXXXXXXXXXXXXXXXX...|alpha+beta|\n", "| 4WN5.A|GSHMGRGAFLSRHSLDM...| 0.2962963| 0.37962964|0.32407406|XXCCCCCCEEEEECTTC...|XXCCCCCCEEEEECCCC...|alpha+beta|\n", "| 4WN5.B|GSHMGRGAFLSRHSLDM...| 0.33333334| 0.37142858| 0.2952381|XXXXXCCCEEEEECTTC...|XXXXXCCCEEEEECCCC...|alpha+beta|\n", "| 4WND.A|GPGSMEASCLELALEGE...| 0.8358663| 0.0|0.16413374|XXXXCCSCHHHHHHHHH...|XXXXCCCCHHHHHHHHH...| alpha|\n", "| 4WND.B|GPLGSDLPPKVVPSKQL...|0.115384616| 0.0|0.88461536|XXXXXXXXXXXXXXXCC...|XXXXXXXXXXXXXXXCC...| other|\n", "| 
4WP6.A|GSHHHHHHSQDPMQAAQ...| 0.45695364|0.119205296|0.42384106|XXXXXXXXXXXXXXXXX...|XXXXXXXXXXXXXXXXX...| other|\n", "| 4WP9.A|FQGAMGSRVVILFTDIE...| 0.3939394| 0.3151515|0.29090908|XXCCSSEEEEEEEEEET...|XXCCCCEEEEEEEEEEC...|alpha+beta|\n", "| 4WP9.B|FQGAMGSRVVILFTDIE...| 0.4| 0.3125| 0.2875|XXXCCSEEEEEEEEEET...|XXXCCCEEEEEEEEEEC...|alpha+beta|\n", "| 4WPG.A|GPLLEMILITGSNGQLG...| 0.39372823| 0.17073171|0.43554008|XCCSCCEEEESTTSHHH...|XCCCCCEEEECCCCHHH...|alpha+beta|\n", "| 4WPK.A|MHHHHHHGMASMTARPL...| 0.4122807|0.114035085|0.47368422|XXXXXXXXXXCTTTSCH...|XXXXXXXXXXCCCCCCH...| other|\n", "| 4WQD.A|MEPPTVALTVPAAALLP...| 0.3991228|0.057017542|0.54385966|XXXXCBCCCCCCGGGCC...|XXXXCECCCCCCHHHCC...| other|\n", "| 4WRI.A|GILANLKEPSAHWCRKM...| 0.62032086|0.053475935| 0.3262032|XXXXXCCCCCHHHHHHH...|XXXXXCCCCCHHHHHHH...| other|\n", "| 4WSF.A|TTDTRRRVKLYALNAER...| 0.16216215| 0.4774775|0.36036035|XXCCTTEEEEEEECTTS...|XXCCCCEEEEEEECCCC...|alpha+beta|\n", "| 4WSF.B| PDESSADVVFKKPLAPAPR| 0.0| 0.0| 1.0| XXXXXXXCCSCCCSSCCCX| XXXXXXXCCCCCCCCCCCX| other|\n", "| 1GWM.A|MNVRATYTVIFKNASGL...|0.039215688| 0.503268|0.45751634|CCCSCCEEEEESSCSSS...|CCCCCCEEEEECCCCCC...| beta|\n", "| 1GXM.A|GLVPRGSHMTGRMLTLD...| 0.42901236| 0.13580246| 0.4351852|XXXXXXXXCBTTBCCCT...|XXXXXXXXCECCECCCC...| other|\n", "| 1GXM.B|GLVPRGSHMTGRMLTLD...| 0.4186747| 0.12951808|0.45180723|CCCCTTTTCBTTBCCCT...|CCCCCCCCCECCECCCC...| other|\n", "| 1GXR.A|DYFQGAMGSKPAYSFHV...| 0.0| 0.5432836|0.45671642|CCEEEEEEEEECCEEEE...|CCEEEEEEEEECCEEEE...| beta|\n", "| 1GXR.B|DYFQGAMGSKPAYSFHV...| 0.0| 0.5555556|0.44444445|CCEEEEEEEEECCEEET...|CCEEEEEEEEECCEEEC...| beta|\n", "+----------------+--------------------+-----------+-----------+----------+--------------------+--------------------+----------+\n", "only showing top 20 rows\n", "\n" ] } ], "source": [ "data.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create a Word2Vec representation of the protein sequences\n", "\n", "**n = 2** # create 2-grams \n", 
"\n", "**windowSize = 25** # 25-amino residue window size for Word2Vector\n", "\n", "**vectorSize = 50** # dimension of feature vector" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n", " <th>structureChainId</th>\n", " <th>sequence</th>\n", " <th>alpha</th>\n", " <th>beta</th>\n", " <th>coil</th>\n", " <th>dsspQ8Code</th>\n", " <th>dsspQ3Code</th>\n", " <th>foldType</th>\n", " <th>ngram</th>\n", " <th>features</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n", " <th>0</th>\n", " <td>4WMY.A</td>\n", " <td>TDWSHPQFEKSTDEANTYFKEWTCSSSPSLPRSCKEIKDECPSAFD...</td>\n", " <td>0.190813</td>\n", " <td>0.268551</td>\n", " <td>0.540636</td>\n", " <td>XXXXXXXXXXXXXXXXXXXXXXXCCCCCCCCSSHHHHHHHCTTCCS...</td>\n", " <td>XXXXXXXXXXXXXXXXXXXXXXXCCCCCCCCCCHHHHHHHCCCCCC...</td>\n", " <td>alpha+beta</td>\n", " <td>[TD, DW, WS, SH, HP, PQ, QF, FE, EK, KS, ST, T...</td>\n", " <td>[0.028354697964596942, 0.06656068684991266, 0....</td>\n", " </tr>\n",
" <tr>\n", " <th>1</th>\n", " <td>4WMY.B</td>\n", " <td>TDWSHPQFEKSTDEANTYFKEWTCSSSPSLPRSCKEIKDECPSAFD...</td>\n", " <td>0.170819</td>\n", " <td>0.263345</td>\n", " <td>0.565836</td>\n", " <td>XXXXXXXXXXXXXXXXXXXXXCCCXXXXCCCSSHHHHHHHCTTCCS...</td>\n", " <td>XXXXXXXXXXXXXXXXXXXXXCCCXXXXCCCCCHHHHHHHCCCCCC...</td>\n", " <td>alpha+beta</td>\n", " <td>[TD, DW, WS, SH, HP, PQ, QF, FE, EK, KS, ST, T...</td>\n", " <td>[0.028354697964596942, 0.06656068684991266, 0....</td>\n", " </tr>\n",
" <tr>\n", " <th>2</th>\n", " <td>4WN5.A</td>\n", " <td>GSHMGRGAFLSRHSLDMKFTYCDDRIAEVAGYSPDDLIGCSAYEYI...</td>\n", " <td>0.296296</td>\n", " <td>0.379630</td>\n", " <td>0.324074</td>\n", " <td>XXCCCCCCEEEEECTTCBEEEECGGHHHHHSCCHHHHBTSBGGGGB...</td>\n", " <td>XXCCCCCCEEEEECCCCEEEEECHHHHHHHCCCHHHHECCEHHHHE...</td>\n", " <td>alpha+beta</td>\n", " <td>[GS, SH, HM, MG, GR, RG, GA, AF, FL, LS, SR, R...</td>\n", " <td>[-0.04048257577641491, 0.1233881547426184, 0.3...</td>\n", " </tr>\n",
" <tr>\n", " <th>3</th>\n", " <td>4WN5.B</td>\n", " <td>GSHMGRGAFLSRHSLDMKFTYCDDRIAEVAGYSPDDLIGCSAYEYI...</td>\n", " <td>0.333333</td>\n", " <td>0.371429</td>\n", " <td>0.295238</td>\n", " <td>XXXXXCCCEEEEECTTCBEEEECGGHHHHHSCCHHHHBTSBGGGGB...</td>\n", " <td>XXXXXCCCEEEEECCCCEEEEECHHHHHHHCCCHHHHECCEHHHHE...</td>\n", " <td>alpha+beta</td>\n", " <td>[GS, SH, HM, MG, GR, RG, GA, AF, FL, LS, SR, R...</td>\n", " <td>[-0.04048257577641491, 0.1233881547426184, 0.3...</td>\n", " </tr>\n",
" <tr>\n", " <th>4</th>\n", " <td>4WND.A</td>\n", " <td>GPGSMEASCLELALEGERLCKSGDCRAGVSFFEAAVQVGTEDLKTL...</td>\n", " <td>0.835866</td>\n", " <td>0.000000</td>\n", " <td>0.164134</td>\n", " <td>XXXXCCSCHHHHHHHHHHHHHTTCHHHHHHHHHHHHHHCCSCHHHH...</td>\n", " <td>XXXXCCCCHHHHHHHHHHHHHCCCHHHHHHHHHHHHHHCCCCHHHH...</td>\n", " <td>alpha</td>\n", " <td>[GP, PG, GS, SM, ME, EA, AS, SC, CL, LE, EL, L...</td>\n", " <td>[-0.009619595496742813, 0.03677304709491171, 0...</td>\n", " </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>
" ], "text/plain": [ " structureChainId sequence \\\n", "0 4WMY.A TDWSHPQFEKSTDEANTYFKEWTCSSSPSLPRSCKEIKDECPSAFD... \n", "1 4WMY.B TDWSHPQFEKSTDEANTYFKEWTCSSSPSLPRSCKEIKDECPSAFD... \n", "2 4WN5.A GSHMGRGAFLSRHSLDMKFTYCDDRIAEVAGYSPDDLIGCSAYEYI... \n", "3 4WN5.B GSHMGRGAFLSRHSLDMKFTYCDDRIAEVAGYSPDDLIGCSAYEYI... \n", "4 4WND.A GPGSMEASCLELALEGERLCKSGDCRAGVSFFEAAVQVGTEDLKTL... \n", "\n", " alpha beta coil \\\n", "0 0.190813 0.268551 0.540636 \n", "1 0.170819 0.263345 0.565836 \n", "2 0.296296 0.379630 0.324074 \n", "3 0.333333 0.371429 0.295238 \n", "4 0.835866 0.000000 0.164134 \n", "\n", " dsspQ8Code \\\n", "0 XXXXXXXXXXXXXXXXXXXXXXXCCCCCCCCSSHHHHHHHCTTCCS... \n", "1 XXXXXXXXXXXXXXXXXXXXXCCCXXXXCCCSSHHHHHHHCTTCCS... \n", "2 XXCCCCCCEEEEECTTCBEEEECGGHHHHHSCCHHHHBTSBGGGGB... \n", "3 XXXXXCCCEEEEECTTCBEEEECGGHHHHHSCCHHHHBTSBGGGGB... \n", "4 XXXXCCSCHHHHHHHHHHHHHTTCHHHHHHHHHHHHHHCCSCHHHH... \n", "\n", " dsspQ3Code foldType \\\n", "0 XXXXXXXXXXXXXXXXXXXXXXXCCCCCCCCCCHHHHHHHCCCCCC... alpha+beta \n", "1 XXXXXXXXXXXXXXXXXXXXXCCCXXXXCCCCCHHHHHHHCCCCCC... alpha+beta \n", "2 XXCCCCCCEEEEECCCCEEEEECHHHHHHHCCCHHHHECCEHHHHE... alpha+beta \n", "3 XXXXXCCCEEEEECCCCEEEEECHHHHHHHCCCHHHHECCEHHHHE... alpha+beta \n", "4 XXXXCCCCHHHHHHHHHHHHHCCCHHHHHHHHHHHHHHCCCCHHHH... alpha \n", "\n", " ngram \\\n", "0 [TD, DW, WS, SH, HP, PQ, QF, FE, EK, KS, ST, T... \n", "1 [TD, DW, WS, SH, HP, PQ, QF, FE, EK, KS, ST, T... \n", "2 [GS, SH, HM, MG, GR, RG, GA, AF, FL, LS, SR, R... \n", "3 [GS, SH, HM, MG, GR, RG, GA, AF, FL, LS, SR, R... \n", "4 [GP, PG, GS, SM, ME, EA, AS, SC, CL, LE, EL, L... \n", "\n", " features \n", "0 [0.028354697964596942, 0.06656068684991266, 0.... \n", "1 [0.028354697964596942, 0.06656068684991266, 0.... \n", "2 [-0.04048257577641491, 0.1233881547426184, 0.3... \n", "3 [-0.04048257577641491, 0.1233881547426184, 0.3... \n", "4 [-0.009619595496742813, 0.03677304709491171, 0... 
" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "encoder = ProteinSequenceEncoder(data)\n", "data = encoder.overlapping_ngram_word2vec_encode(n=2, windowSize=25, vectorSize=50).cache()\n", "\n", "data.toPandas().head(5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Keep only a subset of relevant fields for further processing" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "data = data.select(['structureChainId','alpha','beta','coil','foldType','features'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Write to parquet file" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "data.write.mode('overwrite').format('parquet').save('./input_features')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Terminate Spark" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "spark.stop()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.13" } }, "nbformat": 4, "nbformat_minor": 4 }