Model merging
⚠️ This is an experimental feature
Model merging takes several compatible VW models and merges them into a single model that approximately represents all of the models combined. This will likely never be as effective as a single model trained on all of the data sequentially. However, when training against all data sequentially is not feasible, the speedup from parallel computation can make a merged model which sees all data more effective than a model trained on only a subset of the data.
When using model merging it is important to pass `--preserve_performance_counters` when loading the models to be merged. However, when loading a merged model, the counters need to be reset before continuing to train on it. This can be done by writing and then reading the model without the `--preserve_performance_counters` option.
The API is exposed in multiple places:

- `vw-merge` CLI tool
- Python API
- C++ API
- Java (`VowpalWabbitNative.mergeModels`)
The general shape of this API is consistent across these locations. The API accepts a list of VW models, loaded as workspaces, in `workspaces_to_merge`, and returns a unique pointer to a `VW::workspace` which is the merged result.
There are two modes in which this API can be used:

1. The models to be merged were trained from scratch
2. The models to be merged were all trained from some common base model

In case one, `base_workspace` should be `nullptr`. In case two, `base_workspace` should be passed as the common base model. This ensures that the differences from the common base can be taken into account.
If a `logger` is passed, it is used as the logger for the duration of the function and is also set as the logger of the produced merged model.
```cpp
std::unique_ptr<VW::workspace> merge_models(const VW::workspace* base_workspace,
    const std::vector<const VW::workspace*>& workspaces_to_merge, VW::io::logger* logger = nullptr);
```
Generally speaking, merging is a weighted average of all given models, weighted by the relative amount of data each model has processed. Values which act as counters are summed instead of averaged.
In the case of the `gd` reduction, when `save_resume` is in use, the `adaptive` values are used to do a per-parameter weighted average of the model weights. For all other averaged values in a model, the number of examples seen by each model determines its weight in the average.
If a reduction defines a `save_load` function, this implies that the reduction has training state which is persisted. Therefore, a rule of thumb is that if a reduction defines `save_load` it must also define `merge`. A warning is emitted if any reduction in the stack has a `save_load` but no `merge`, and an error is emitted if the base reduction in the stack has no `merge`, as merging definitely will not work in that case.
The signature of the merge function depends on whether or not the reduction is a base reduction. Ideally, all `merge` functions would use the non-base signature, but since base learners use the weights and other state stored in `VW::workspace`, that is not currently feasible.
```cpp
using ReductionDataT = void; // ...

// Base reduction
using merge_with_all_fn = void (*)(const std::vector<float>& per_model_weighting, const VW::workspace& base_workspace,
    const std::vector<const VW::workspace*>& all_workspaces, const ReductionDataT& base_data,
    const std::vector<ReductionDataT*>& all_data, VW::workspace& output_workspace, ReductionDataT& output_data);

// Non-base reduction
using merge_fn = void (*)(const std::vector<float>& per_model_weighting, const ReductionDataT& base_data,
    const std::vector<const ReductionDataT*>& all_data, ReductionDataT& output_data);
```
This is then set on the respective learner builder during construction. The following is then exposed on the learner object:
```cpp
void merge(const std::vector<float>& per_model_weighting, const VW::workspace& base_workspace,
    const std::vector<const VW::workspace*>& all_workspaces, const base_learner* base_workspaces_learner,
    const std::vector<const base_learner*>& all_learners, VW::workspace& output_workspace,
    base_learner* output_learner);
```