#!/usr/bin/env python # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. from __future__ import absolute_import from __future__ import division from __future__ import print_function import heapq import math import numpy as np class Caption(object): """ A complete or partial caption object """ def __init__(self, sentence, state, logprob, score): """Initializes the Caption""" # list of word_ids in the caption self.sentence = sentence # model state after generating the previous word self.state = state # log probability of the caption self.logprob = logprob # score of the caption self.score = score def __cmp__(self, other): """Compares Captions by score""" assert isinstance(other, Caption) if self.score == other.score: return 0 elif self.score < other.score: return -1 else: return 1 # for Python 3 compatibility (__cmp__ is deprecated). def __lt__(self, other): assert isinstance(other, Caption) return self.score < other.score # also for Python 3 compatibility. def __eq__(self, other): assert isinstance(other, Caption) return self.score == other.score class TopN(object): """Maintains the top N elements of an incrementally provided set""" def __init__(self, n): self._n = n self._data = [] def size(self): assert self._data is not None return len(self._data) def push(self, x): """Pushes a new element""" assert self._data is not None if len(self._data) < self._n: heapq.heappush(self._data, x) else: heapq.heappushpop(self._data, x) def extract(self, sort=False): """ Extracts all elements from the TopN. This is a destructive operation, The only method that can be called immediately after extract() is reset() """ assert self._data is not None data = self._data self._data = None if sort: data.sort(reverse=True) return data def reset(self): """Returns the TopN to an empty state""" self._data = [] class CaptionGenerator(object): """ Class to generate captions from an image-to-text model """ def __init__(self, model, vocab, beam_size, max_caption_length, length_normalization_factor=0.0): self.vocab = vocab self.model = model self.beam_size = beam_size self.max_caption_length = max_caption_length self.length_normalization_factor = length_normalization_factor def beam_search(self, sess, encoded_image): """Runs beam search caption generation on a single image""" # feed in the image to get the initial state. initial_state = self.model.feed_image(sess, encoded_image) initial_beam = Caption( sentence=[self.vocab.start_id], state=initial_state[0], logprob=0.0, score=0.0) partial_captions = TopN(self.beam_size) partial_captions.push(initial_beam) complete_captions = TopN(self.beam_size) # run beam search. for _ in range(self.max_caption_length - 1): partial_captions_list = partial_captions.extract() partial_captions.reset() input_feed = np.array([c.sentence[-1] for c in partial_captions_list]) state_feed = np.array([c.state for c in partial_captions_list]) softmax, new_states = self.model.inference_step(sess, input_feed, state_feed) for i, partial_caption in enumerate(partial_captions_list): word_probabilities = softmax[i] state = new_states[i] # for this partial caption, get the beam_size most probable next words. words_and_probs = list(enumerate(word_probabilities)) words_and_probs.sort(key=lambda x: -x[1]) words_and_probs = words_and_probs[0:self.beam_size] # each next word gives a new partial caption. for w, p in words_and_probs: if p < 1e-12: continue # avoid log(0). sentence = partial_caption.sentence + [w] logprob = partial_caption.logprob + math.log(p) score = logprob if w == self.vocab.end_id: if self.length_normalization_factor > 0: score /= len(sentence) ** self.length_normalization_factor beam = Caption(sentence, state, logprob, score) complete_captions.push(beam) else: beam = Caption(sentence, state, logprob, score) partial_captions.push(beam) if partial_captions.size() == 0: # we have run out of partial candidates; happens when beam_size = 1. break # if we have no complete captions then fall back to the partial captions, # but never output a mixture of complete and partial captions because a # partial caption could have a higher score than all the complete captions if not complete_captions.size(): complete_captions = partial_captions return complete_captions.extract(sort=True)