Caffe2 - C++ API
A deep learning, cross-platform ML framework
video_decoder.cc
1 #include <caffe2/video/video_decoder.h>
2 #include <caffe2/core/logging.h>
3 
4 #include <stdio.h>
5 #include <mutex>
6 #include <random>
7 
8 extern "C" {
9 #include <libavcodec/avcodec.h>
10 #include <libavformat/avformat.h>
11 #include <libavutil/log.h>
12 #include <libswresample/swresample.h>
13 #include <libswscale/swscale.h>
14 }
15 
16 namespace caffe2 {
17 
18 VideoDecoder::VideoDecoder() {
19  static bool gInitialized = false;
20  static std::mutex gMutex;
21  std::unique_lock<std::mutex> lock(gMutex);
22  if (!gInitialized) {
23  av_register_all();
24  avcodec_register_all();
25  avformat_network_init();
26  gInitialized = true;
27  }
28 }
29 
30 void VideoDecoder::ResizeAndKeepAspectRatio(
31  const int origHeight,
32  const int origWidth,
33  const int heightMin,
34  const int widthMin,
35  int& outHeight,
36  int& outWidth) {
37  float min_aspect = (float)heightMin / (float)widthMin;
38  float video_aspect = (float)origHeight / (float)origWidth;
39  if (video_aspect >= min_aspect) {
40  outWidth = widthMin;
41  outHeight = (int)ceil(video_aspect * outWidth);
42  } else {
43  outHeight = heightMin;
44  outWidth = (int)ceil(outHeight / video_aspect);
45  }
46 }
47 
// Core decoding routine shared by decodeFile()/decodeMemory().
// Reads a video through the caller-supplied IO context, selects a video
// stream, optionally seeks to a sampling start point (temporal jitter or an
// explicit start frame), then decodes packets, rescales each sampled frame
// to the target resolution/pixel format, and appends it to sampledFrames.
//
// @param videoName      label used only in log messages
// @param ioctx          custom AVIO context supplying the encoded bytes
// @param params         decoding/sampling configuration
// @param start_frm      first frame to decode when decode_type_ is
//                       USE_START_FRM
// @param sampledFrames  [out] cleared, then filled with the decoded frames
void VideoDecoder::decodeLoop(
    const string& videoName,
    VideoIOContext& ioctx,
    const Params& params,
    const int start_frm,
    std::vector<std::unique_ptr<DecodedFrame>>& sampledFrames) {
  AVPixelFormat pixFormat = params.pixelFormat_;
  AVFormatContext* inputContext = avformat_alloc_context();
  AVStream* videoStream_ = nullptr;
  AVCodecContext* videoCodecContext_ = nullptr;
  AVFrame* videoStreamFrame_ = nullptr;
  AVPacket packet;
  av_init_packet(&packet); // init packet
  SwsContext* scaleContext_ = nullptr;

  try {
    // Route all reads/seeks through the caller-provided IO context instead
    // of letting libavformat open a file itself.
    inputContext->pb = ioctx.get_avio();
    inputContext->flags |= AVFMT_FLAG_CUSTOM_IO;
    int ret = 0;

    // Determining the input format: probe the first 32KB of the stream.
    int probeSz = 32 * 1024 + AVPROBE_PADDING_SIZE;
    DecodedFrame::AvDataPtr probe((uint8_t*)av_malloc(probeSz));

    memset(probe.get(), 0, probeSz);
    int len = ioctx.read(probe.get(), probeSz - AVPROBE_PADDING_SIZE);
    if (len < probeSz - AVPROBE_PADDING_SIZE) {
      // NOTE(review): the early returns in this try block leak inputContext
      // (it is only released at the end of the try body) — confirm, and
      // consider an RAII wrapper.
      LOG(ERROR) << "Insufficient data to determine video format";
      return;
    }

    // seek back to start of stream
    ioctx.seek(0, SEEK_SET);

    unique_ptr<AVProbeData> probeData(new AVProbeData());
    probeData->buf = probe.get();
    probeData->buf_size = len;
    probeData->filename = "";
    // Determine the input-format:
    inputContext->iformat = av_probe_input_format(probeData.get(), 1);

    ret = avformat_open_input(&inputContext, "", nullptr, nullptr);
    if (ret < 0) {
      LOG(ERROR) << "Unable to open stream " << ffmpegErrorStr(ret);
      return;
    }

    ret = avformat_find_stream_info(inputContext, nullptr);
    if (ret < 0) {
      LOG(ERROR) << "Unable to find stream info in " << videoName << " "
                 << ffmpegErrorStr(ret);
      return;
    }

    // Decode the first video stream
    int videoStreamIndex_ = params.streamIndex_;
    if (videoStreamIndex_ == -1) {
      for (int i = 0; i < inputContext->nb_streams; i++) {
        auto stream = inputContext->streams[i];
        if (stream->codec->codec_type == AVMEDIA_TYPE_VIDEO) {
          videoStreamIndex_ = i;
          videoStream_ = stream;
          break;
        }
      }
    }
    // NOTE(review): when params.streamIndex_ != -1 the loop above is skipped
    // and videoStream_ is never assigned, so the check below always fails —
    // confirm whether an explicit stream index was ever meant to work here.

    if (videoStream_ == nullptr) {
      LOG(ERROR) << "Unable to find video stream in " << videoName << " "
                 << ffmpegErrorStr(ret);
      return;
    }

    // Initialize codec
    AVDictionary* opts = nullptr;
    videoCodecContext_ = videoStream_->codec;
    try {
      ret = avcodec_open2(
          videoCodecContext_,
          avcodec_find_decoder(videoCodecContext_->codec_id),
          &opts);
    } catch (const std::exception&) {
      LOG(ERROR) << "Exception during open video codec";
      return;
    }

    if (ret < 0) {
      LOG(ERROR) << "Cannot open video codec : "
                 << videoCodecContext_->codec->name;
      return;
    }

    // Calculate if we need to rescale the frames
    int origWidth = videoCodecContext_->width;
    int origHeight = videoCodecContext_->height;
    int outWidth = origWidth;
    int outHeight = origHeight;

    if (params.video_res_type_ == VideoResType::ORIGINAL_RES) {
      // if the original resolution is too low,
      // make its size at least (crop_height, crop_width)
      if (params.crop_width_ > origWidth || params.crop_height_ > origHeight) {
        ResizeAndKeepAspectRatio(
            origHeight,
            origWidth,
            params.crop_height_,
            params.crop_width_,
            outHeight,
            outWidth);
      }
    } else if (
        params.video_res_type_ == VideoResType::USE_MINIMAL_WIDTH_HEIGHT) {
      // resize the image to be at least
      // (height_min, width_min) resolution while keep the aspect ratio
      ResizeAndKeepAspectRatio(
          origHeight,
          origWidth,
          params.height_min_,
          params.width_min_,
          outHeight,
          outWidth);
    } else if (params.video_res_type_ == VideoResType::USE_WIDTH_HEIGHT) {
      // resize the image to the predefined
      // resolution and ignore the aspect ratio
      outWidth = params.scale_w_;
      outHeight = params.scale_h_;
    } else {
      LOG(ERROR) << "Unknown video_res_type: " << params.video_res_type_;
    }

    // Make sure that we have a valid format
    CAFFE_ENFORCE_NE(videoCodecContext_->pix_fmt, AV_PIX_FMT_NONE);

    // Create a scale context that converts decoded frames from the codec's
    // native size/pixel format to the requested output size/format.
    scaleContext_ = sws_getContext(
        videoCodecContext_->width,
        videoCodecContext_->height,
        videoCodecContext_->pix_fmt,
        outWidth,
        outHeight,
        pixFormat,
        SWS_FAST_BILINEAR,
        nullptr,
        nullptr,
        nullptr);

    // Getting video meta data
    VideoMeta videoMeta;
    videoMeta.codec_type = videoCodecContext_->codec_type;
    videoMeta.width = outWidth;
    videoMeta.height = outHeight;
    videoMeta.pixFormat = pixFormat;
    videoMeta.fps = av_q2d(videoStream_->avg_frame_rate);

    // If sampledFrames is not empty, empty it
    if (sampledFrames.size() > 0) {
      sampledFrames.clear();
    }

    if (params.intervals_.size() == 0) {
      LOG(ERROR) << "Empty sampling intervals.";
      return;
    }

    // Validate the sampling intervals: the first must start at timestamp 0
    // and timestamps must be strictly ascending (only logged, not enforced).
    std::vector<SampleInterval>::const_iterator itvlIter =
        params.intervals_.begin();
    if (itvlIter->timestamp != 0) {
      LOG(ERROR) << "Sampling interval starting timestamp is not zero.";
    }

    double currFps = itvlIter->fps;
    if (currFps < 0 && currFps != SpecialFps::SAMPLE_ALL_FRAMES &&
        currFps != SpecialFps::SAMPLE_TIMESTAMP_ONLY) {
      // fps must be 0, -1, -2 or > 0
      LOG(ERROR) << "Invalid sampling fps.";
    }

    double prevTimestamp = itvlIter->timestamp;
    itvlIter++;
    if (itvlIter != params.intervals_.end() &&
        prevTimestamp >= itvlIter->timestamp) {
      LOG(ERROR) << "Sampling interval timestamps must be strictly ascending.";
    }

    double lastFrameTimestamp = -1.0;
    // Initialize frame and packet.
    // These will be reused across calls.
    videoStreamFrame_ = av_frame_alloc();

    // frame index in video stream
    int frameIndex = -1;
    // frame index of output frames
    int outputFrameIndex = -1;

    /* identify the starting point from where we must start decoding */
    std::mt19937 meta_randgen(time(nullptr));
    int start_ts = -1;
    bool mustDecodeAll = false;
    if (videoStream_->duration > 0 && videoStream_->nb_frames > 0) {
      /* we have a valid duration and nb_frames. We can safely
       * detect an intermediate timestamp to start decoding from. */

      // leave a margin of 10 frames to take in to account the error
      // from av_seek_frame
      int margin =
          int(ceil((10 * videoStream_->duration) / (videoStream_->nb_frames)));
      // if we need to do temporal jittering
      if (params.decode_type_ == DecodeType::DO_TMP_JITTER) {
        /* estimate the average duration for the required # of frames */
        double maxFramesDuration =
            (videoStream_->duration * params.num_of_required_frame_) /
            (videoStream_->nb_frames);
        int ts1 = 0;
        int ts2 = videoStream_->duration - int(ceil(maxFramesDuration));
        ts2 = ts2 > 0 ? ts2 : 0;
        // pick a random timestamp between ts1 and ts2. ts2 is selected such
        // that you have enough frames to satisfy the required # of frames.
        start_ts = std::uniform_int_distribution<>(ts1, ts2)(meta_randgen);
        // seek a frame at start_ts
        ret = av_seek_frame(
            inputContext,
            videoStreamIndex_,
            std::max(0, start_ts - margin),
            AVSEEK_FLAG_BACKWARD);

        // if we need to decode from the start_frm
      } else if (params.decode_type_ == DecodeType::USE_START_FRM) {
        // map the requested start frame number to a stream timestamp
        start_ts = int(floor(
            (videoStream_->duration * start_frm) / (videoStream_->nb_frames)));
        // seek a frame at start_ts
        ret = av_seek_frame(
            inputContext,
            videoStreamIndex_,
            std::max(0, start_ts - margin),
            AVSEEK_FLAG_BACKWARD);
      } else {
        mustDecodeAll = true;
      }

      if (ret < 0) {
        LOG(ERROR) << "Unable to decode from a random start point";
        /* fall back to default decoding of all frames from start */
        av_seek_frame(inputContext, videoStreamIndex_, 0, AVSEEK_FLAG_BACKWARD);
        mustDecodeAll = true;
      }
    } else {
      /* we do not have the necessary metadata to selectively decode frames.
       * Decode all frames as we do in the default case */
      LOG(INFO) << " Decoding all frames as we do not have suffiecient"
                   " metadata for selective decoding.";
      mustDecodeAll = true;
    }

    int gotPicture = 0;
    int eof = 0;
    int selectiveDecodedFrames = 0;

    // Upper bound on decoded frames in selective mode; uniform sampling
    // keeps decoding up to the global cap instead.
    int maxFrames = (params.decode_type_ == DecodeType::DO_UNIFORM_SMP)
        ? MAX_DECODING_FRAMES
        : params.num_of_required_frame_;
    // There is a delay between reading packets from the
    // transport and getting decoded frames back.
    // Therefore, after EOF, continue going while
    // the decoder is still giving us frames.
    while ((!eof || gotPicture) &&
           /* either you must decode all frames or decode upto maxFrames
            * based on status of the mustDecodeAll flag */
           (mustDecodeAll ||
            ((!mustDecodeAll) && (selectiveDecodedFrames < maxFrames)))) {
      try {
        if (!eof) {
          ret = av_read_frame(inputContext, &packet);

          if (ret == AVERROR(EAGAIN)) {
            av_free_packet(&packet);
            continue;
          }
          // Interpret any other error as EOF
          if (ret < 0) {
            eof = 1;
            av_free_packet(&packet);
            continue;
          }

          // Ignore packets from other streams
          if (packet.stream_index != videoStreamIndex_) {
            av_free_packet(&packet);
            continue;
          }
        }

        // After EOF this is called with the (empty) last packet to flush
        // the decoder's internal buffer.
        ret = avcodec_decode_video2(
            videoCodecContext_, videoStreamFrame_, &gotPicture, &packet);
        if (ret < 0) {
          LOG(ERROR) << "Error decoding video frame : " << ffmpegErrorStr(ret);
        }

        try {
          // Nothing to do without a picture
          if (!gotPicture) {
            av_free_packet(&packet);
            continue;
          }
          frameIndex++;

          double frame_ts =
              av_frame_get_best_effort_timestamp(videoStreamFrame_);
          double timestamp = frame_ts * av_q2d(videoStream_->time_base);

          if ((frame_ts >= start_ts && !mustDecodeAll) || mustDecodeAll) {
            /* process current frame if:
             * 1) We are not doing selective decoding and mustDecodeAll
             *    OR
             * 2) We are doing selective decoding and current frame
             *   timestamp is >= start_ts from where we start selective
             *   decoding*/
            // if reaching the next interval, update the current fps
            // and reset lastFrameTimestamp so the current frame could be
            // sampled (unless fps == SpecialFps::SAMPLE_NO_FRAME)
            if (itvlIter != params.intervals_.end() &&
                timestamp >= itvlIter->timestamp) {
              lastFrameTimestamp = -1.0;
              currFps = itvlIter->fps;
              prevTimestamp = itvlIter->timestamp;
              itvlIter++;
              if (itvlIter != params.intervals_.end() &&
                  prevTimestamp >= itvlIter->timestamp) {
                LOG(ERROR)
                    << "Sampling interval timestamps must be strictly ascending.";
              }
            }

            // keyFrame will bypass all checks on fps sampling settings
            bool keyFrame = params.keyFrames_ && videoStreamFrame_->key_frame;
            if (!keyFrame) {
              // if fps == SpecialFps::SAMPLE_NO_FRAME (0), don't sample at all
              if (currFps == SpecialFps::SAMPLE_NO_FRAME) {
                av_free_packet(&packet);
                continue;
              }

              // fps is considered reached in the following cases:
              // 1. lastFrameTimestamp < 0 - start of a new interval
              //    (or first frame)
              // 2. currFps == SpecialFps::SAMPLE_ALL_FRAMES (-1) - sample every
              //    frame
              // 3. timestamp - lastFrameTimestamp has reached target fps and
              //    currFps > 0 (not special fps setting)
              // different modes for fps:
              // SpecialFps::SAMPLE_NO_FRAMES (0):
              //    disable fps sampling, no frame sampled at all
              // SpecialFps::SAMPLE_ALL_FRAMES (-1):
              //    unlimited fps sampling, will sample at native video fps
              // SpecialFps::SAMPLE_TIMESTAMP_ONLY (-2):
              //    disable fps sampling, but will get the frame at specific
              //    timestamp
              // others (> 0): decoding at the specified fps
              bool fpsReached = lastFrameTimestamp < 0 ||
                  currFps == SpecialFps::SAMPLE_ALL_FRAMES ||
                  (currFps > 0 &&
                   timestamp >= lastFrameTimestamp + (1 / currFps));

              if (!fpsReached) {
                av_free_packet(&packet);
                continue;
              }
            }

            lastFrameTimestamp = timestamp;

            outputFrameIndex++;
            if (params.maximumOutputFrames_ != -1 &&
                outputFrameIndex >= params.maximumOutputFrames_) {
              // enough frames
              av_free_packet(&packet);
              break;
            }

            AVFrame* rgbFrame = av_frame_alloc();
            if (!rgbFrame) {
              // NOTE(review): allocation failure is only logged; the code
              // below would then dereference a null rgbFrame — confirm.
              LOG(ERROR) << "Error allocating AVframe";
            }

            try {
              // Determine required buffer size and allocate buffer
              int numBytes = avpicture_get_size(pixFormat, outWidth, outHeight);
              DecodedFrame::AvDataPtr buffer(
                  (uint8_t*)av_malloc(numBytes * sizeof(uint8_t)));

              int size = avpicture_fill(
                  (AVPicture*)rgbFrame,
                  buffer.get(),
                  pixFormat,
                  outWidth,
                  outHeight);

              // Rescale/convert the decoded frame into the output buffer.
              sws_scale(
                  scaleContext_,
                  videoStreamFrame_->data,
                  videoStreamFrame_->linesize,
                  0,
                  videoCodecContext_->height,
                  rgbFrame->data,
                  rgbFrame->linesize);

              unique_ptr<DecodedFrame> frame = make_unique<DecodedFrame>();
              frame->width_ = outWidth;
              frame->height_ = outHeight;
              frame->data_ = move(buffer);
              frame->size_ = size;
              frame->index_ = frameIndex;
              frame->outputFrameIndex_ = outputFrameIndex;
              frame->timestamp_ = timestamp;
              frame->keyFrame_ = videoStreamFrame_->key_frame;

              sampledFrames.push_back(move(frame));
              selectiveDecodedFrames++;
              av_frame_free(&rgbFrame);
            } catch (const std::exception&) {
              // still release the scratch frame on failure
              av_frame_free(&rgbFrame);
            }
          }
          av_frame_unref(videoStreamFrame_);
        } catch (const std::exception&) {
          av_frame_unref(videoStreamFrame_);
        }

        av_free_packet(&packet);
      } catch (const std::exception&) {
        av_free_packet(&packet);
      }
    } // of while loop

    // free all stuffs
    sws_freeContext(scaleContext_);
    av_packet_unref(&packet);
    av_frame_free(&videoStreamFrame_);
    avcodec_close(videoCodecContext_);
    avformat_close_input(&inputContext);
    avformat_free_context(inputContext);
  } catch (const std::exception&) {
    // In case of decoding error
    // free all stuffs
    sws_freeContext(scaleContext_);
    av_packet_unref(&packet);
    av_frame_free(&videoStreamFrame_);
    avcodec_close(videoCodecContext_);
    avformat_close_input(&inputContext);
    avformat_free_context(inputContext);
  }
}
499 
500 void VideoDecoder::decodeMemory(
501  const char* buffer,
502  const int size,
503  const Params& params,
504  const int start_frm,
505  std::vector<std::unique_ptr<DecodedFrame>>& sampledFrames) {
506  VideoIOContext ioctx(buffer, size);
507  decodeLoop(string("Memory Buffer"), ioctx, params, start_frm, sampledFrames);
508 }
509 
510 void VideoDecoder::decodeFile(
511  const string& file,
512  const Params& params,
513  const int start_frm,
514  std::vector<std::unique_ptr<DecodedFrame>>& sampledFrames) {
515  VideoIOContext ioctx(file);
516  decodeLoop(file, ioctx, params, start_frm, sampledFrames);
517 }
518 
519 string VideoDecoder::ffmpegErrorStr(int result) {
520  std::array<char, 128> buf;
521  av_strerror(result, buf.data(), buf.size());
522  return string(buf.data());
523 }
524 
525 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...