import string
from collections import Counter, defaultdict

import networkx as nx
import numpy as np
from gensim import corpora
from scipy.spatial import distance
from sklearn import manifold
from tqdm import tqdm

from nodesAndEdges import writeNodesEdges, readLongestParagraphs

npar = 30   # target number of paragraphs taken from each text (30 for a simple demo)
size = []   # actual number of paragraphs per text, passed to readLongestParagraphs

# One row per text, in load order: (file, author tag, novel-per-author tag, year).
# Texts by the same author share an author tag; the second tag numbers that
# author's works.
library = [
    ('timeMachine.txt',                 1, 1, 1895),   # H. G. Wells
    ('oliverTwist.txt',                 2, 1, 1838),   # Charles Dickens
    ('adventuresOfHuckleberryFinn.txt', 3, 1, 1884),   # Mark Twain
    ('theWarOfTheWorlds.txt',           1, 2, 1898),   # H. G. Wells
    ('astro.txt',                       4, 1, 2009),   # astrophysics paper
    ('brothersKaramazov.txt',           5, 1, 1880),   # Fyodor Dostoevsky
]
fileNames, authorTag, novelPerAuthorTag, yearTag = map(list, zip(*library))

# Concatenate the longest paragraphs of every text, in table order.
documents = readLongestParagraphs(fileNames[0], size, npar)
for fileName in fileNames[1:]:
    documents += readLongestParagraphs(fileName, size, npar)

# Normalize punctuation in every paragraph. Line breaks and dashes become
# spaces, so hyphenated words split into separate tokens. Every other
# character in string.punctuation is deleted. One str.translate pass per
# paragraph replaces one str.replace pass per punctuation character.
punctTable = {ord(c): None for c in string.punctuation}
punctTable[ord('\n')] = ' '
punctTable[ord('-')] = ' '   # override the deletion entry: '-' is a separator
# Assign through a slice so the existing `documents` list is updated in place.
documents[:] = [p.translate(punctTable) for p in documents]

# Tokenize: lower-case each paragraph, split it on whitespace, and drop a
# handful of very common function words.
stoplist = {'for', 'from', 'a', 'of', 'the', 'and', 'to', 'in', 'at', 'through'}
texts = []
for document in documents:
    words = document.lower().split()
    texts.append([w for w in words if w not in stoplist])

# Count how often each token occurs across all paragraphs.
frequency = Counter(token for text in texts for token in text)

# Drop tokens that occur only once in the whole collection. Such a token lives
# in a single paragraph, so it can never be a word two paragraphs share.
texts = [[token for token in text if frequency[token] > 1] for text in texts]

# Assign an integer ID to every remaining token. This gensim Dictionary is
# built directly from `texts` and does not reuse the `frequency` counts.
# (dictionary.token2id maps each word to its ID, which is handy when debugging.)
dictionary = corpora.Dictionary(texts)
nwords = len(dictionary)   # gensim's Dictionary.__len__ returns len(token2id)
print(f'built a global dictionary with {nwords} words')

# Turn each paragraph into a sparse bag-of-words: a list of (wordID, wordCount)
# tuples, one list per paragraph.
corpus = [dictionary.doc2bow(text) for text in texts]

# Expand the sparse vectors into a dense (paragraphs x nwords) count matrix.
n = sum(size)
# Rows of fullCorpus and the per-paragraph tags built later from `size` must
# line up, so refuse to continue if the two paragraph counts disagree.
if n != len(corpus):
    raise ValueError(f'paragraph count mismatch: sum(size) is {n} '
                     f'but the corpus has {len(corpus)} paragraphs')
fullCorpus = np.zeros((n, nwords), dtype=np.int32)
for row, bow in enumerate(corpus):
    # wordId (not `id`) avoids shadowing the builtin for the rest of the script
    for wordId, count in bow:
        fullCorpus[row, wordId] = count

# Connect every pair of paragraphs (i < j) that have at least minCommon
# distinct words in common.
minCommon = 5
n = len(fullCorpus)
# 0/1 word-presence matrix. Its Gram matrix holds, for every pair of
# paragraphs, the number of distinct words both contain.
present = (fullCorpus != 0).astype(np.int32)
common = present @ present.T
# Keep the strict upper triangle so each pair appears once and self-pairs are
# excluded. argwhere lists the pairs in ascending (i, j) order, and tolist()
# yields plain [i, j] lists of Python ints.
edges = np.argwhere(np.triu(common >= minCommon, k=1)).tolist()

# Expand the per-text tags into per-paragraph lists. Text k contributes
# size[k] consecutive copies of each of its tags, in paragraph order.
author = [authorTag[k]
          for k, count in enumerate(size) for _ in range(count)]
novelPerAuthor = [novelPerAuthorTag[k]
                  for k, count in enumerate(size) for _ in range(count)]
year = [yearTag[k]
        for k, count in enumerate(size) for _ in range(count)]

# Build the paragraph graph: one node per paragraph, one edge per entry in `edges`.
H = nx.Graph()
H.add_nodes_from(range(n))
H.add_edges_from(edges)

# 2D spring layout. No seed is passed, so positions differ between runs.
# pos maps nodeID -> array([x, y]).
pos = nx.spring_layout(H, k=0.9, dim=2)

# Output coordinates: layout x and y, plus the paragraph's year scaled by
# 1/800 as the z coordinate.
# NOTE(review): the 1/800 factor looks empirical -- confirm the intended z range.
xyz = np.zeros((n, 3))
for node, (x, y) in pos.items():
    xyz[node] = x, y, year[node] / 800.

print(H.number_of_nodes(), 'nodes and', H.number_of_edges(), 'edges')

print(xyz)
writeNodesEdges(xyz, edges=edges, scalar=[author, novelPerAuthor],
                name=['author', 'novel per author'], fileout='network8')
