Unsorted segment reduction op with optional fused embedding lookup. More...
#include <segment_reduction_op.h>
Public Types | |
| enum | { INDICES = Reducer::kInputCount, SEGMENT_IDS = Reducer::kInputCount + (SparseFused ? 1 : 0) } |
Public Types inherited from caffe2::Observable< OperatorBase > | |
| using | Observer = ObserverBase< OperatorBase > |
Public Member Functions | |
| AbstractUnsortedSegmentOp (const OperatorDef &operator_def, Workspace *ws) | |
| bool | RunOnDevice () override |
| template<typename IndexType > | |
| bool | DoRunWithType () |
| template<typename IndexType , int FixedSize> | |
| bool | DoRunWithValue () |
Public Member Functions inherited from caffe2::Operator< Context > | |
| Operator (const OperatorDef &operator_def, Workspace *ws) | |
| const Tensor< Context > & | Input (int idx) |
| Tensor< Context > * | Output (int idx) |
| void | WaitEvent (const Event &ev, int stream_id=-1) final |
| void | WaitEvents (const std::vector< const Event * > &events, int stream_id=-1) final |
| bool | Run (int stream_id=0) final |
| bool | RunAsync (int stream_id=0) final |
| bool | IsStreamFree (int stream_id) const override |
| bool | HasAsyncPart () const override |
| bool | SupportsAsyncScheduling () const override |
| const Context * | getContext () const |
Public Member Functions inherited from caffe2::OperatorBase | |
| OperatorBase (const OperatorDef &operator_def, Workspace *ws) | |
| bool | HasArgument (const string &name) const |
| Checks if the operator has an argument of the given name. | |
| template<typename T > | |
| T | GetSingleArgument (const string &name, const T &default_value) const |
| template<typename T > | |
| bool | HasSingleArgumentOfType (const string &name) const |
| template<typename T > | |
| vector< T > | GetRepeatedArgument (const string &name, const vector< T > &default_value={}) const |
| template<typename T > | |
| const T & | Input (int idx) |
| template<typename T > | |
| T * | Output (int idx) |
| template<typename T > | |
| T * | Output (int idx, T *allocated) |
| const Blob & | InputBlob (int idx) |
| Blob * | OutputBlob (int idx) |
| template<typename T > | |
| bool | InputIsType (int idx) |
| template<typename T > | |
| bool | OutputIsType (int idx) |
| int | InputSize () const |
| int | OutputSize () const |
| const vector< const Blob * > & | Inputs () const |
| const vector< Blob * > & | Outputs () |
| vector< TensorShape > | InputTensorShapes () |
| void | Wait (const OperatorBase &other, int stream_id=-1) |
| virtual void | Finish () |
| virtual void | AddRelatedBlobInfo (EnforceNotMet *err) |
| const OperatorDef & | debug_def () const |
| void | set_debug_def (const std::shared_ptr< const OperatorDef > &operator_def) |
| bool | has_debug_def () const |
| void | RecordLastFailedOpNetPosition () |
| int | net_position () const |
| void | set_net_position (int idx) |
| const DeviceOption & | device_option () const |
| const Event & | event () const |
| Event & | event () |
| void | ResetEvent () |
| void | DisableEvent () |
| bool | IsEventDisabled () const |
| const std::string & | type () const |
| void | annotate_engine (const std::string &engine) |
| const std::string & | engine () const |
Public Member Functions inherited from caffe2::Observable< OperatorBase > | |
| const Observer * | AttachObserver (std::unique_ptr< Observer > observer) |
| std::unique_ptr< Observer > | DetachObserver (const Observer *observer_ptr) |
| Returns a unique_ptr to the removed observer. More... | |
| virtual size_t | NumObservers () |
| void | StartAllObservers () |
| void | StopAllObservers () |
Data Fields | |
| USE_OPERATOR_CONTEXT_FUNCTIONS | |
Static Public Attributes | |
| static constexpr int | kSelfInputs = SparseFused ? 2 : 1 |
| static constexpr int | kNumInputs = Reducer::kInputCount + kSelfInputs |
Static Public Attributes inherited from caffe2::OperatorBase | |
| static constexpr int | kNoNetPositionSet = -1 |
Additional Inherited Members | |
Protected Member Functions inherited from caffe2::Operator< Context > | |
| void | RecordEvent (const char *err_msg=nullptr) final |
| std::string | getErrorMsg () |
Protected Member Functions inherited from caffe2::OperatorBase | |
| DISABLE_COPY_AND_ASSIGN (OperatorBase) | |
Protected Attributes inherited from caffe2::Operator< Context > | |
| Context | context_ |
Protected Attributes inherited from caffe2::OperatorBase | |
| std::unique_ptr< Event > | event_ |
Protected Attributes inherited from caffe2::Observable< OperatorBase > | |
| std::vector< std::unique_ptr< Observer > > | observers_list_ |
Unsorted segment reduction op with optional fused embedding lookup.
Base implementation for UnsortedSegmentXXX and SparseUnsortedSegmentXXX ops, selected by the SparseFused template argument.
Unlike the sorted version, it allows "gaps" in the segment ids.
Inputs:
0: DATA - input embedding to do lookups in
1..P: AUX_ARG_<I> - optional additional arguments passed to the reducer; each should have the same first dimension as SEGMENT_IDS (e.g. the scalars in WeightedSum)
P+1: INDICES - (only when SparseFused is true) 1-D vector of indices to look up in DATA; should have the same dimension as SEGMENT_IDS
P+1 (or P+2 when SparseFused is true): SEGMENT_IDS - 1-D vector of unsorted segment ids
Args: num_segments - overrides the first dimension of the output. If not set, it is inferred from the SEGMENT_IDS tensor.
Output: Tensor whose first dimension is K, where K is the maximum segment id + 1 (or num_segments, if set). The remaining dimensions are determined by the reducer but usually match the trailing dimensions of DATA.
Definition at line 994 of file segment_reduction_op.h.
1.8.11