Caffe2 - C++ API
A deep learning, cross platform ML framework
fc_inference.cc
#include "caffe2/operators/fc_inference.h"

#include <cstdint>
2 
3 namespace caffe2 {
4 std::vector<TensorShape> FCShapeInference(
5  const OperatorDef& def,
6  const vector<TensorShape>& in,
7  bool pretransposed_weight) {
8  vector<TensorShape> out(1);
9  ArgumentHelper helper(def);
10 
11  auto axis = helper.GetSingleArgument<int32_t>("axis", 1);
12  const auto canonical_axis = canonical_axis_index_(axis, in[0].dims().size());
13  auto axis_w = helper.GetSingleArgument<int32_t>("axis_w", 1);
14  const int canonical_axis_w =
15  canonical_axis_index_(axis_w, in[1].dims().size());
16  const int N = pretransposed_weight
17  ? size_from_dim_(canonical_axis_w, GetDimsVector(in[1]))
18  : size_to_dim_(canonical_axis_w, GetDimsVector(in[1]));
19 
20  vector<int> y_shape(in[0].dims().begin(), in[0].dims().end());
21  CAFFE_ENFORCE_LE(canonical_axis + 1, y_shape.size());
22  y_shape.resize(canonical_axis + 1);
23  y_shape[canonical_axis] = N;
24  out[0] = CreateTensorShape(y_shape, in[0].data_type());
25  return out;
26 }
27 
28 OpSchema::Cost CostInferenceForFC(
29  const OperatorDef& def,
30  const vector<TensorShape>& in) {
31  struct OpSchema::Cost c;
32  ArgumentHelper helper(def);
33 
34  auto axis = helper.GetSingleArgument<int32_t>("axis", 1);
35  const auto canonical_axis = canonical_axis_index_(axis, in[0].dims().size());
36  const int M = size_to_dim_(canonical_axis, GetDimsVector(in[0]));
37  const int K = size_from_dim_(canonical_axis, GetDimsVector(in[0]));
38  auto axis_w = helper.GetSingleArgument<int32_t>("axis_w", 1);
39  const int canonical_axis_w =
40  canonical_axis_index_(axis_w, in[1].dims().size());
41  const int N = size_to_dim_(canonical_axis_w, GetDimsVector(in[1]));
42  c.flops = 2 * K * M * N + M * N;
43  c.bytes_moved = M * N * sizeof(float);
44  c.params_bytes = (K * N + N) * sizeof(float);
45  return c;
46 }
47 } // namespace caffe2
TIndex size_from_dim_(int k, const vector< TIndex > &dims)
Returns the product of all dimensions starting from index `k`.
Definition: tensor.h:40
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...