Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40

41

42

43

44

45

46

47

48

49

50

51

52

53

54

55

56

57

58

59

60

61

62

63

64

65

66

67

68

69

70

71

72

73

74

75

76

77

78

79

80

81

82

83

84

85

86

87

88

89

90

91

92

93

94

95

96

97

98

99

100

101

102

103

104

105

106

107

108

# Natural Language Toolkit: Interface to TADM Classifier 

# 

# Copyright (C) 2001-2012 NLTK Project 

# Author: Joseph Frazee <jfrazee@mail.utexas.edu> 

# URL: <http://www.nltk.org/> 

# For license information, see LICENSE.TXT 

from __future__ import print_function 

 

import sys 

import subprocess 

 

from nltk import compat 

from nltk.internals import find_binary 

try: 

    import numpy 

except ImportError: 

    numpy = None 

 

# Location of the ``tadm`` executable; stays None until config_tadm() runs.
_tadm_bin = None


def config_tadm(bin=None):
    """
    Look up the ``tadm`` binary and store the result in the module-level
    ``_tadm_bin`` variable.

    :param bin: Passed to ``find_binary`` as its second positional
        argument (presumably an explicit path to the executable --
        confirm against ``nltk.internals.find_binary``).
    """
    global _tadm_bin
    lookup_options = dict(
        env_vars=['TADM_DIR'],
        binary_names=['tadm'],
        url='http://tadm.sf.net',
    )
    _tadm_bin = find_binary('tadm', bin, **lookup_options)

 

def write_tadm_file(train_toks, encoding, stream):
    """
    Write a ``tadm`` input file for the given corpus of classified
    tokens.

    For every training token, one line holding the number of known
    labels is written, followed by one line per known label.  Each of
    those lines holds a 0/1 flag (1 when the known label equals the
    token's own label), the number of pairs in the encoded vector, and
    then the pairs themselves, each formatted as two integers.

    :type train_toks: list(tuple(dict, str))
    :param train_toks: Training data, represented as a list of
        pairs, the first member of which is a feature dictionary,
        and the second of which is a classification label.
    :type encoding: TadmEventMaxentFeatureEncoding
    :param encoding: A feature encoding, used to convert featuresets
        into feature vectors.
    :type stream: stream
    :param stream: The stream to which the ``tadm`` input file should be
        written.
    """
    # File format descriptions:
    #
    # http://sf.net/forum/forum.php?thread_id=1391502&forum_id=473054
    # http://sf.net/forum/forum.php?thread_id=1675097&forum_id=473054
    known_labels = encoding.labels()
    for featureset, gold_label in train_toks:
        header = '%d\n' % len(known_labels)
        stream.write(header)
        for candidate in known_labels:
            vector = encoding.encode(featureset, candidate)
            is_gold = int(gold_label == candidate)
            pair_count = len(vector)
            pair_text = ' '.join('%d %d' % pair for pair in vector)
            stream.write('%d %d %s\n' % (is_gold, pair_count, pair_text))

 

def parse_tadm_weights(paramfile):
    """
    Read a weight vector from the output produced by ``tadm`` when
    training a model.

    Every line yielded by ``paramfile`` is stripped of surrounding
    whitespace and converted to a float; the values are returned, in
    order, as a ``numpy`` array of type ``'d'``.

    :param paramfile: An iterable of lines, one weight per line.
    :return: The weight vector as a ``numpy`` array.
    """
    values = [float(raw_line.strip()) for raw_line in paramfile]
    return numpy.array(values, 'd')

 

def call_tadm(args):
    """
    Call the ``tadm`` binary with the given arguments.

    The binary is located through ``config_tadm()`` the first time it is
    needed.  The child process is given ``sys.stdout`` as its standard
    output, while its standard error is captured so that it can be shown
    when the command fails.

    :param args: The command-line arguments for ``tadm``: a list (or
        other non-string iterable) of strings.
    :raise TypeError: If ``args`` is a single string rather than a
        sequence of strings.
    :raise OSError: If ``tadm`` exits with a non-zero return code.
    """
    if isinstance(args, compat.string_types):
        raise TypeError('args should be a list of strings')
    if _tadm_bin is None:
        config_tadm()

    # Call tadm via a subprocess.  stderr must be piped explicitly:
    # communicate() returns None for any stream that was not opened with
    # PIPE, which made the failure report below print just "None".
    cmd = [_tadm_bin] + list(args)
    p = subprocess.Popen(cmd, stdout=sys.stdout, stderr=subprocess.PIPE)
    (_, stderr) = p.communicate()

    # The pipe yields bytes on Python 3; turn them into text so the
    # report is readable.  On Python 2 the value is already a ``str``.
    if not isinstance(stderr, str):
        stderr = stderr.decode('utf-8', 'replace')

    # Check the return code.
    if p.returncode != 0:
        print()
        print(stderr)
        raise OSError('tadm command failed!')

    # The command succeeded: pass along anything tadm wrote to stderr so
    # that capturing it does not hide warnings from the user.
    if stderr:
        sys.stderr.write(stderr)

 

def names_demo():
    """
    Run ``nltk.classify.util.names_demo`` with
    ``TadmMaxentClassifier.train`` as its argument.
    """
    # Alias the imported helper so it is not confused with this function.
    from nltk.classify.util import names_demo as run_names_demo
    from nltk.classify.maxent import TadmMaxentClassifier
    run_names_demo(TadmMaxentClassifier.train)

 

def encoding_demo():
    """
    Train a ``TadmEventMaxentFeatureEncoding`` on three small example
    tokens, write the resulting ``tadm`` input file to ``sys.stdout``,
    and then print the description of every index below
    ``encoding.length()`` alongside that index.
    """
    import sys
    from nltk.classify.maxent import TadmEventMaxentFeatureEncoding
    from nltk.classify.tadm import write_tadm_file

    samples = [
        ({'f0': 1, 'f1': 1, 'f3': 1}, 'A'),
        ({'f0': 1, 'f2': 1, 'f4': 1}, 'B'),
        ({'f0': 2, 'f2': 1, 'f3': 1, 'f4': 1}, 'A'),
    ]
    demo_encoding = TadmEventMaxentFeatureEncoding.train(samples)
    write_tadm_file(samples, demo_encoding, sys.stdout)
    print()
    for feature_id in range(demo_encoding.length()):
        description = demo_encoding.describe(feature_id)
        print('%s --> %d' % (description, feature_id))
    print()

 

if __name__ == '__main__':
    # Run the demos in order: the encoding demo first, then the names demo.
    for demo in (encoding_demo, names_demo):
        demo()