What is a learner?
VW works by transforming a problem from one type to another through a set of reductions. This allows us to leverage hardened solutions to existing problems to solve new problems. In VW, this is implemented using a reduction stack, which is a chain of `learner` objects. Each learner represents one distinct learning algorithm.
Conceptually, there are two categories of learners: reduction learners and bottom learners. A reduction learner requires at least one learner below it in the stack. We call that learner its base, and the reduction will recursively call into it, using the base learner's output to compute its own result. A bottom learner does not require any further learners below it, and it directly returns a result. The bottom of the reduction stack must be a bottom learner, and all other learners in the stack must be reduction learners.
This document describes how a learner is implemented in the codebase. Learners are defined by a set of types, functions, and a couple of data fields. Learners are created using learner builder objects, which are templated to enforce type consistency. After a learner is created, the learner object itself is fully type-erased, so all learners are the exact same C++ type: `class learner`.
A learner has several important data types:
- `DataT` - The type of the data object of this learner. Each learner has its own data object to store internal state.
- `ExampleT` - The type of example this reduction expects. Either `example` or `multi_ex`.
- `label_type_t` - Used for two label types: the type this learner expects examples to have, and the type this learner produces for its base
- `prediction_type_t` - Used for two prediction types: the type this learner expects its base to produce, and the type this learner itself produces
Note that `DataT` and `ExampleT` are template parameters for learner builders. They are type-erased in the learner builder so that the resulting learner object does not reference them. However, `label_type_t` and `prediction_type_t` are enums, and are used for data fields in the learner object.
The types `label_type_t` and `prediction_type_t` are used to define a contract between a reduction learner and its base. These properties must be satisfied in order for the reduction stack to work:
- The output label type of a reduction should match the input label type of its base
- The input prediction type of a reduction should match the output prediction type of its base
More details, including special cases for bottom learners, are provided on this page: Matching Label and Prediction Types Between Reductions
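To make the contract concrete, here is a minimal, self-contained sketch. It is not VW's actual API; the enum values and field names are hypothetical stand-ins for the learner's type fields described below.

```cpp
#include <cassert>

// Hypothetical enum values and field names, for illustration only.
enum class label_type_t { simple, multiclass, nolabel };
enum class prediction_type_t { scalar, multiclass };

struct learner_types
{
  label_type_t input_label;
  label_type_t output_label;
  prediction_type_t input_pred;
  prediction_type_t output_pred;
};

// The contract from above: labels flow downward, predictions flow upward.
bool contract_holds(const learner_types& reduction, const learner_types& base)
{
  return reduction.output_label == base.input_label  // label handed to the base
      && reduction.input_pred == base.output_pred;   // prediction received back
}

int main()
{
  // A multiclass-style reduction sitting on top of a scalar bottom learner.
  learner_types reduction{label_type_t::multiclass, label_type_t::simple,
                          prediction_type_t::scalar, prediction_type_t::multiclass};
  learner_types bottom{label_type_t::simple, label_type_t::nolabel,
                       prediction_type_t::scalar, prediction_type_t::scalar};
  assert(contract_holds(reduction, bottom));
  return 0;
}
```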
This is an overview of important fields in the `learner` class.
- `std::string _name` - Human-readable name for the learner
- `size_t weights` - Describes how many weight vectors are required by this learner. This means that, in effect, several models can be referenced by this learner.
- `size_t increment` - Used along with the per-call increment to reference different weight vectors
- `bool is_multiline` - `true` if the expected `ExampleT` is `multi_ex`, otherwise `false` and `ExampleT` is `example`
- Input and output prediction and label types:
  - `prediction_type_t _output_pred_type`
  - `prediction_type_t _input_pred_type`
  - `label_type_t _output_label_type`
  - `label_type_t _input_label_type`
- `std::shared_ptr<void> _learner_data` - The data object for this learner. Note that here it has been type-erased from `DataT` to `void`.
- `std::shared_ptr<learner> _base_learner` - The base of this learner. It points to the learner object immediately below this one in the reduction stack.
  - As a `shared_ptr`, this gives each reduction ownership of its base learner. Multiple learners are allowed to share the same base, but this is very uncommon.
  - For bottom learners, this will be `nullptr` because there is no learner below them.
  - Note that the reduction stack cannot be traversed from bottom to top. You can only go from top to bottom.
The logic of a learner is implemented by several `std::function` objects. For the overwhelming majority of reductions, only `learn`, `predict`, and `finish_example` are important.
Functions can be assigned to a learner only via learner builders. The learner builder takes function pointers to fully-typed functions (with `DataT` and `ExampleT`) and type-erases them. This is done by binding some arguments via lambda capture so that the resulting function can be stored in the same generic `std::function` type for all learners.
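As an illustration of that mechanism, here is a minimal, self-contained sketch of type erasure via lambda capture. The names (`my_data`, `my_learn`) are hypothetical; this shows the general C++ pattern, not VW's actual builder code.

```cpp
#include <functional>
#include <iostream>
#include <memory>

// Hypothetical stand-ins for a reduction's DataT and ExampleT.
struct my_data { float state = 0.f; };
struct example { float label = 0.f; };

// A fully-typed function, as a reduction author would write it.
void my_learn(my_data& data, example& ex) { data.state += ex.label; }

int main()
{
  auto data = std::make_shared<my_data>();

  // Capture the data object in a lambda. The stored signature no longer
  // mentions my_data, so every learner can use the same std::function type.
  std::function<void(example&)> erased = [data](example& ex) { my_learn(*data, ex); };

  example ex{1.5f};
  erased(ex);
  std::cout << data->state << '\n';  // prints 1.5
  return 0;
}
```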
Some functions of a learner are auto-recursive: the corresponding function for each learner in the stack is invoked in sequence automatically, without any individual function having to explicitly call the base learner's function. This is done by the implementation of the learner class itself, as the sketch below illustrates.

Not all functions are auto-recursive. Some functions must explicitly call the corresponding function of their base learner.
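Here is a minimal sketch of how such auto-recursion can work, using a stripped-down, hypothetical learner type rather than VW's real one:

```cpp
#include <functional>
#include <iostream>
#include <memory>

struct learner
{
  std::function<void()> end_pass_fn;  // this learner's own end-of-pass logic
  std::shared_ptr<learner> base;      // nullptr for a bottom learner

  void end_pass()
  {
    if (end_pass_fn) { end_pass_fn(); }
    // The learner class walks the stack itself; no end_pass_fn ever
    // needs to call its base explicitly.
    if (base) { base->end_pass(); }
  }
};

int main()
{
  auto bottom = std::make_shared<learner>();
  bottom->end_pass_fn = [] { std::cout << "bottom end_pass\n"; };

  learner top;
  top.end_pass_fn = [] { std::cout << "top end_pass\n"; };
  top.base = bottom;

  top.end_pass();  // runs top's logic, then bottom's, automatically
  return 0;
}
```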
Details for important functions in a learner are provided in the following sections. Note that all function signatures given here are those before type-erasure. When implementing a new learner, the functions you write should have the signatures given below, with `DataT` and `ExampleT` replaced by your specific data and example types. You will provide a function pointer to the learner builder, which expects fully-typed function pointers as inputs and stores type-erased `std::function` objects in the learner.
`void(DataT* data);`

This is called once by the driver when it starts up. It does not auto-recurse; the definition in the top-most learner is used.
`void(DataT* data, BaseT* base_learner, ExampleT* example);`

These three functions (`learn`, `predict`, and `update`) are perhaps the most important. They define the core learning process. `update` is not commonly used, and by default it simply refers to `learn`.
These functions will not auto-recurse. However, in nearly all cases, you will want to use the result of the base learner. Thus, you are responsible for implementing a call to the appropriate function in the base learner.
Each example passed to this function implicitly has a valid `label_type_t` object associated with it. Additionally, when `ExampleT == VW::example` there is an allocated and empty prediction object on the example, and when `ExampleT == VW::multi_ex` there is an allocated and empty prediction object on the zeroth example.
When the base learner is called, any examples that are passed to it MUST adhere to the contract described previously. This is a very important requirement that, if broken, causes serious and sometimes hard-to-find bugs. Your implementations of these functions are responsible for satisfying the contract; the sketch below shows the typical shape.
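Below is a minimal sketch of that shape. All types and names are hypothetical stand-ins, not VW's actual API: the reduction puts a base-compatible label on the example, calls the base, then converts the base's prediction into its own output type.

```cpp
struct simple_label { float value = 0.f; };

struct example
{
  int multiclass_label = 0;  // label type this reduction receives
  simple_label label;        // label type its base expects
  float scalar_pred = 0.f;   // prediction type the base produces
  int multiclass_pred = 0;   // prediction type this reduction produces
};

struct base_learner
{
  void learn(example& ex) { ex.scalar_pred = ex.label.value; }  // stand-in
};

struct my_data { };

void learn(my_data& /*data*/, base_learner& base, example& ex)
{
  // Satisfy the contract: give the example the label type the base expects.
  ex.label.value = (ex.multiclass_label == 1) ? 1.f : -1.f;
  base.learn(ex);
  // Translate the base's prediction back into this learner's output type.
  ex.multiclass_pred = (ex.scalar_pred > 0.f) ? 1 : 2;
}

int main()
{
  my_data data;
  base_learner base;
  example ex;
  ex.multiclass_label = 1;
  learn(data, base, ex);  // ex.multiclass_pred is now 1
  return 0;
}
```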
`void(DataT* data, BaseT& base, ExampleT* ex, size_t count, size_t step, polyprediction* pred, bool finalize_predictions);`
Multipredict makes several predictions using one example. Each call increments the offset, so it effectively uses a different weight vector for each prediction. It is often used internally by reductions, but rarely externally.
Multipredict does not need to be defined. By default, the learner implementation will automatically fall back to `predict` if `multipredict` is undefined; see the sketch after the parameter list below.
- `pred` is an array of `count` `polyprediction` objects.
- `step` is the weight increment to be applied per prediction.
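The following self-contained sketch illustrates the fallback behavior described above. It assumes a hypothetical `ft_offset` field on the example that selects the weight vector; VW's real example type carries a similar offset.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-ins for VW's prediction and example types.
struct polyprediction { float scalar = 0.f; };
struct example { std::size_t ft_offset = 0; };

float predict_at_offset(const example& ex) { return static_cast<float>(ex.ft_offset); }

// Emulate multipredict with repeated predict calls, one per weight vector.
void multipredict_fallback(example& ex, std::size_t count, std::size_t step,
                           polyprediction* pred)
{
  const std::size_t saved_offset = ex.ft_offset;
  for (std::size_t i = 0; i < count; ++i)
  {
    pred[i].scalar = predict_at_offset(ex);
    ex.ft_offset += step;  // advance to the next weight vector
  }
  ex.ft_offset = saved_offset;  // leave the example as we found it
}

int main()
{
  example ex;
  std::vector<polyprediction> preds(4);
  multipredict_fallback(ex, preds.size(), 2, preds.data());
  return 0;
}
```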
`float(DataT* data, BaseT* base, ExampleT* example);`
Does not auto-recurse.
`void(vw&, DataT* data, ExampleT* ex);`
Finish example is called after `learn`/`predict` and is where the reduction must calculate and report loss, as well as free any resources that were allocated for that example. Additionally, the example label and prediction must be returned to a clean slate.
Does not auto-recurse.
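As an illustration, here is a minimal sketch of the duties described above, with hypothetical stand-ins for VW's workspace and example types:

```cpp
// Hypothetical stand-ins, not VW's actual types.
struct vw_workspace { double sum_loss = 0.0; };
struct example { float label = 0.f; float pred = 0.f; };
struct my_data { };

void finish_example(vw_workspace& all, my_data& /*data*/, example& ex)
{
  const float err = ex.pred - ex.label;
  all.sum_loss += err * err;  // calculate and report the loss
  // Return the label and prediction to a clean slate for reuse.
  ex.label = 0.f;
  ex.pred = 0.f;
}

int main()
{
  vw_workspace all;
  my_data data;
  example ex;
  ex.label = 1.f;
  ex.pred = 0.5f;
  finish_example(all, data, ex);
  return 0;
}
```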
`void(DataT* data);`

Called at the end of a learning pass. This function is auto-recursive.
`void(DataT* data);`

Called once all of the examples have been parsed and processed by the reduction stack. This function is auto-recursive.
`void(DataT* data);`

Called as the reduction is being destroyed. Note, however, that if the reduction data type `DataT` has a destructor, it will be called automatically, so this function is often unnecessary. This function is auto-recursive.
`void(DataT* data, io_buf* model_buffer, bool read, bool text);`

This is how the reduction implements serialization and deserialization to and from a model file. This function is auto-recursive.
- `read` is true if a model file is being read, and false if the expectation is to write to the buffer
- `text` means that a human-readable model should be written instead of binary
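Here is a minimal, self-contained sketch of the single-function read/write pattern described above. `io_buf` here is a hypothetical stand-in for VW's real buffer type, which has its own read/write API:

```cpp
#include <cstdio>

// Hypothetical stand-ins, for illustration only.
struct io_buf { std::FILE* f = nullptr; };
struct my_data { float threshold = 0.5f; };

void save_load(my_data& data, io_buf& model, bool read, bool text)
{
  if (model.f == nullptr) { return; }
  if (read)
  {
    // Deserialize: read the state back in the same order it was written.
    if (std::fread(&data.threshold, sizeof(data.threshold), 1, model.f) != 1) { return; }
  }
  else if (text)
  {
    // Human-readable model output.
    std::fprintf(model.f, "threshold %f\n", data.threshold);
  }
  else
  {
    // Binary model output.
    std::fwrite(&data.threshold, sizeof(data.threshold), 1, model.f);
  }
}

int main()
{
  io_buf model{std::tmpfile()};
  my_data data;
  save_load(data, model, /*read=*/false, /*text=*/false);  // write
  std::rewind(model.f);
  save_load(data, model, /*read=*/true, /*text=*/false);   // read back
  std::fclose(model.f);
  return 0;
}
```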