# Machine learning - Feature extraction

This demo creates a feature vector for protein fold classification.
The goal is to classify a protein chain as either an all-alpha or an all-beta protein based on its sequence. Overlapping n-grams and a Word2Vec representation of the protein sequence serve as the feature vector.

[Word2Vec model](https://spark.apache.org/docs/latest/mllib-feature-extraction.html#word2vec)

[Word2Vec example](https://spark.apache.org/docs/latest/ml-features.html#word2vec)

## Imports

In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, when
from mmtfPyspark.io import mmtfReader
from mmtfPyspark.webfilters import Pisces
from mmtfPyspark.filters import ContainsLProteinChain
from mmtfPyspark.mappers import StructureToPolymerChains
from mmtfPyspark.datasets import secondaryStructureExtractor
from mmtfPyspark.ml import ProteinSequenceEncoder

## Configure Spark Context

In [2]:
spark = SparkSession.builder.appName("1-Features").getOrCreate()

## Read MMTF File and get a set of L-protein chains

In [3]:
pdb = mmtfReader.read_sequence_file('../resources/mmtf_reduced_sample/') \
                .flatMap(StructureToPolymerChains()) \
                .filter(ContainsLProteinChain())

## Get secondary structure content

In [4]:
data = secondaryStructureExtractor.get_dataset(pdb)

In [5]:
data.show(5)

+----------------+--------------------+----------+----------+----------+--------------------+--------------------+
|structureChainId|            sequence|     alpha|      beta|      coil|          dsspQ8Code|          dsspQ3Code|
+----------------+--------------------+----------+----------+----------+--------------------+--------------------+
|          4WMY.A|TDWSHPQFEKSTDEANT...|0.19081272|0.26855123|0.54063606|XXXXXXXXXXXXXXXXX...|XXXXXXXXXXXXXXXXX...|
|          4WMY.B|TDWSHPQFEKSTDEANT...|0.17081851|0.26334518| 0.5658363|XXXXXXXXXXXXXXXXX...|XXXXXXXXXXXXXXXXX...|
|          4WN5.A|GSHMGRGAFLSRHSLDM...| 0.2962963|0.37962964|0.32407406|XXCCCCCCEEEEECTTC...|XXCCCCCCEEEEECCCC...|
|          4WN5.B|GSHMGRGAFLSRHSLDM...|0.33333334|0.37142858| 0.2952381|XXXXXCCCEEEEECTTC...|XXXXXCCCEEEEECCCC...|
|          4WND.A|GPGSMEASCLELALEGE...| 0.8358663|       0.0|0.16413374|XXXXCCSCHHHHHHHHH...|XXXXCCCCHHHHHHHHH...|
+----------------+--------------------+----------+----------+----------+--------------------+--------------------+
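The `dsspQ8Code` and `dsspQ3Code` columns above differ only in how finely they label each residue. Comparing the sample rows suggests the usual DSSP reduction (`G` to `H`, `B` to `E`, `T` and `S` to `C`). The following is a minimal sketch of that reduction in plain Python, not the mmtfPyspark implementation: the helper name `q8_to_q3` is hypothetical, and the mapping of `I` to `H` is an assumption because no `I` appears in the sample.

```python
# Hypothetical helper (not part of mmtfPyspark): reduce DSSP 8-state codes to 3 states.
Q8_TO_Q3 = {
    "H": "H", "G": "H", "I": "H",  # helices -> H ("I" is an assumption, not seen in the sample)
    "E": "E", "B": "E",            # strands and bridges -> E
    "T": "C", "S": "C", "C": "C",  # turns, bends, and coil -> C
}


def q8_to_q3(q8: str) -> str:
    """Map each Q8 character to its Q3 class; other codes such as 'X' pass through unchanged."""
    return "".join(Q8_TO_Q3.get(code, code) for code in q8)


# Visible prefix of the dsspQ8Code value for chain 4WN5.A
print(q8_to_q3("XXCCCCCCEEEEECTTC"))  # XXCCCCCCEEEEECCCC
```

The printed string matches the visible `dsspQ3Code` prefix for the same chain.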

## Define add_protein_fold_type function

In [6]:
def add_protein_fold_type(data, minThreshold, maxThreshold):
    '''
    Adds a column "foldType" with one of three major secondary structure classes,
    "alpha", "beta", or "alpha+beta", or with "other", based on the fraction of
    alpha/beta content.

    The simplified syntax used in this method relies on two imports:
        from pyspark.sql.functions import when
        from pyspark.sql.functions import col

    Args:
        data (DataFrame): input dataset with "alpha" and "beta" fraction columns
        minThreshold (float): a secondary structure type with a fraction below this threshold counts as absent
        maxThreshold (float): a secondary structure type with a fraction above this threshold counts as present

    Returns:
        DataFrame: the input dataset with the added "foldType" column
    '''

    # Conditions are evaluated in order; rows matching none of them are labeled "other".
    return data.withColumn(
        "foldType",
        when((col("alpha") > maxThreshold) & (col("beta") < minThreshold), "alpha")
        .when((col("beta") > maxThreshold) & (col("alpha") < minThreshold), "beta")
        .when((col("alpha") > maxThreshold) & (col("beta") > maxThreshold), "alpha+beta")
        .otherwise("other")
    )
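To make the threshold logic easier to check by hand, here is a minimal plain-Python sketch of the same decision rule applied to a single chain. The function name `fold_type` is hypothetical and exists only for illustration; the Spark version above remains the one used in this demo.

```python
def fold_type(alpha: float, beta: float, min_threshold: float, max_threshold: float) -> str:
    """Plain-Python mirror of the when/otherwise chain, for one (alpha, beta) pair."""
    if alpha > max_threshold and beta < min_threshold:
        return "alpha"
    if beta > max_threshold and alpha < min_threshold:
        return "beta"
    if alpha > max_threshold and beta > max_threshold:
        return "alpha+beta"
    # Everything else, including fractions that fall between the two thresholds
    return "other"


# Fractions taken from the sample rows, with minThreshold=0.05 and maxThreshold=0.15
print(fold_type(0.19081272, 0.26855123, 0.05, 0.15))  # alpha+beta (chain 4WMY.A)
print(fold_type(0.8358663, 0.0, 0.05, 0.15))          # alpha (chain 4WND.A)
# Made-up values: beta lies between the thresholds, so no class applies
print(fold_type(0.20, 0.10, 0.05, 0.15))              # other
```

The last call shows why the two thresholds leave a gap: a chain with a clear helical fraction but a small, non-negligible strand fraction is assigned to "other" rather than to "alpha".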

## Classify chains by secondary structure type

In [7]:
data = add_protein_fold_type(data, minThreshold=0.05, maxThreshold=0.15)

In [8]:
data.show()

+----------------+--------------------+-----------+-----------+----------+--------------------+--------------------+----------+
|structureChainId|            sequence|      alpha|       beta|      coil|          dsspQ8Code|          dsspQ3Code|  foldType|
+----------------+--------------------+-----------+-----------+----------+--------------------+--------------------+----------+
|          4WMY.A|TDWSHPQFEKSTDEANT...| 0.19081272| 0.26855123|0.54063606|XXXXXXXXXXXXXXXXX...|XXXXXXXXXXXXXXXXX...|alpha+beta|
|          4WMY.B|TDWSHPQFEKSTDEANT...| 0.17081851| 0.26334518| 0.5658363|XXXXXXXXXXXXXXXXX...|XXXXXXXXXXXXXXXXX...|alpha+beta|
|          4WN5.A|GSHMGRGAFLSRHSLDM...|  0.2962963| 0.37962964|0.32407406|XXCCCCCCEEEEECTTC...|XXCCCCCCEEEEECCCC...|alpha+beta|
|          4WN5.B|GSHMGRGAFLSRHSLDM...| 0.33333334| 0.37142858| 0.2952381|XXXXXCCCEEEEECTTC...|XXXXXCCCEEEEECCCC...|alpha+beta|
|          4WND.A|GPGSMEASCLELALEGE...|  0.8358663|        0.0|0.16413374|XXXXCCSCHHHHHHHHH...|XXXXCCCCHHHHHHHHH...|     alpha|

## Create a Word2Vec representation of the protein sequences

**n = 2**     # create 2-grams 

**windowSize = 25**    # 25-residue window size for Word2Vec

**vectorSize = 50**    # dimension of feature vector
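With n = 2, each sequence is split into overlapping 2-grams that act as the "words" for Word2Vec. The sketch below illustrates that tokenization in plain Python. The helper name `overlapping_ngrams` is hypothetical, and the code is an illustration of the idea rather than the encoder's actual implementation.

```python
from typing import List


def overlapping_ngrams(sequence: str, n: int) -> List[str]:
    """Slide a window of length n over the sequence, one residue at a time."""
    return [sequence[i:i + n] for i in range(len(sequence) - n + 1)]


# First twelve residues of chain 4WMY.A
print(overlapping_ngrams("TDWSHPQFEKST", 2))
# ['TD', 'DW', 'WS', 'SH', 'HP', 'PQ', 'QF', 'FE', 'EK', 'KS', 'ST']
```

A sequence of length L yields L - n + 1 overlapping n-grams, so neighboring tokens share n - 1 residues.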

In [9]:
encoder = ProteinSequenceEncoder(data)
data = encoder.overlapping_ngram_word2vec_encode(n=2, windowSize=25, vectorSize=50).cache()

# Limit in Spark first so that only five rows are collected to the driver
data.limit(5).toPandas()

,structureChainId,sequence,alpha,beta,coil,dsspQ8Code,dsspQ3Code,foldType,ngram,features
0,4WMY.A,TDWSHPQFEKSTDEANTYFKEWTCSSSPSLPRSCKEIKDECPSAFD...,0.190813,0.268551,0.540636,XXXXXXXXXXXXXXXXXXXXXXXCCCCCCCCSSHHHHHHHCTTCCS...,XXXXXXXXXXXXXXXXXXXXXXXCCCCCCCCCCHHHHHHHCCCCCC...,alpha+beta,"[TD, DW, WS, SH, HP, PQ, QF, FE, EK, KS, ST, T...","[0.028354697964596942, 0.06656068684991266, 0...."
1,4WMY.B,TDWSHPQFEKSTDEANTYFKEWTCSSSPSLPRSCKEIKDECPSAFD...,0.170819,0.263345,0.565836,XXXXXXXXXXXXXXXXXXXXXCCCXXXXCCCSSHHHHHHHCTTCCS...,XXXXXXXXXXXXXXXXXXXXXCCCXXXXCCCCCHHHHHHHCCCCCC...,alpha+beta,"[TD, DW, WS, SH, HP, PQ, QF, FE, EK, KS, ST, T...","[0.028354697964596942, 0.06656068684991266, 0...."
2,4WN5.A,GSHMGRGAFLSRHSLDMKFTYCDDRIAEVAGYSPDDLIGCSAYEYI...,0.296296,0.37963,0.324074,XXCCCCCCEEEEECTTCBEEEECGGHHHHHSCCHHHHBTSBGGGGB...,XXCCCCCCEEEEECCCCEEEEECHHHHHHHCCCHHHHECCEHHHHE...,alpha+beta,"[GS, SH, HM, MG, GR, RG, GA, AF, FL, LS, SR, R...","[-0.04048257577641491, 0.1233881547426184, 0.3..."
3,4WN5.B,GSHMGRGAFLSRHSLDMKFTYCDDRIAEVAGYSPDDLIGCSAYEYI...,0.333333,0.371429,0.295238,XXXXXCCCEEEEECTTCBEEEECGGHHHHHSCCHHHHBTSBGGGGB...,XXXXXCCCEEEEECCCCEEEEECHHHHHHHCCCHHHHECCEHHHHE...,alpha+beta,"[GS, SH, HM, MG, GR, RG, GA, AF, FL, LS, SR, R...","[-0.04048257577641491, 0.1233881547426184, 0.3..."
4,4WND.A,GPGSMEASCLELALEGERLCKSGDCRAGVSFFEAAVQVGTEDLKTL...,0.835866,0.0,0.164134,XXXXCCSCHHHHHHHHHHHHHTTCHHHHHHHHHHHHHHCCSCHHHH...,XXXXCCCCHHHHHHHHHHHHHCCCHHHHHHHHHHHHHHCCCCHHHH...,alpha,"[GP, PG, GS, SM, ME, EA, AS, SC, CL, LE, EL, L...","[-0.009619595496742813, 0.03677304709491171, 0..."
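The `features` column holds one fixed-length vector per chain, regardless of chain length. Spark ML's `Word2VecModel` produces such a vector by averaging the vectors of all words in a document; assuming the encoder relies on that model, each chain's features are the mean of its n-gram embeddings. The sketch below shows the averaging step in plain Python. The helper name `average_embedding` and the two-dimensional toy embeddings are made up for illustration and are not values from a trained model.

```python
from typing import Dict, List


def average_embedding(tokens: List[str], embeddings: Dict[str, List[float]]) -> List[float]:
    """Return the element-wise mean of the embedding vectors of all tokens."""
    vector_size = len(next(iter(embeddings.values())))
    total = [0.0] * vector_size
    for token in tokens:
        for i, value in enumerate(embeddings[token]):
            total[i] += value
    return [value / len(tokens) for value in total]


# Toy 2-dimensional embeddings (made up; the demo uses vectorSize=50)
toy_embeddings = {
    "TD": [1.0, 0.0],
    "DW": [0.0, 1.0],
    "WS": [0.5, 0.5],
}

print(average_embedding(["TD", "DW", "WS"], toy_embeddings))  # [0.5, 0.5]
```

Averaging depends only on the n-grams present, which is consistent with chains 4WMY.A and 4WMY.B showing the same leading feature values in the output above.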


## Keep only a subset of relevant fields for further processing

In [10]:
data = data.select(['structureChainId','alpha','beta','coil','foldType','features'])

## Write to parquet file

In [11]:
data.write.mode('overwrite').format('parquet').save('./input_features')

## Terminate Spark

In [12]:
spark.stop()