Model Drift: Detecting the Faulty Evolution of ML Models
What is Model Drift?
Imagine you're a data scientist: you built a genuinely useful, impactful model and put it into production. Not only that, it actually performs well and people are happy with your work – that is, until a few months, or even just weeks, go by and suddenly your model isn't making the quality predictions it used to.
This is essentially what model drift is – a model starting to perform worse over time. It's typically caused by external changes in the data: your model ends up making predictions on data unlike anything it has seen before, so naturally its performance degrades.
Obviously, the most direct problem with model drift is poor predictive performance, but the downstream consequences matter too, especially later on when deciding how often to check for drift (because monitoring can be pricey). For now, note what undetected model drift can lead to: decreased customer satisfaction, loss of trust from stakeholders and decision-makers, wasted resources, and, probably worst of all, inaccurate medical diagnoses.
Potential Causes
Causes can be split into two main categories: data drift and concept drift. The difference lies in what gets affected. Concept drift is when the properties of the target variable, or the relationship between the target and input variables, change. A change in the target variable's definition is the most obvious example (what counts as spam has changed), but it can also stem from shifts in user behavior or the business environment. Data drift, on the other hand, is when the independent/input variables change (for example, the distribution of one specific input variable shifts).
We can further split potential causes into the following:
Seasonal: The most natural case of model drift is simply seasonal change in the data. For example, if you're predicting demand for outdoor gear, demand for winter-related equipment will obviously rise shortly before the winter season starts.
Sudden: Sometimes big changes happen outside of your control. These may be more obvious to detect, but harder to prepare for, since much of the time it's some external development that affects your model or system. Take a recent example: ChatGPT. When it was released, there was suddenly huge demand for AI and LLM solutions – but how could you have predicted that? You can be sure platforms like Udemy would have highly valued such a prediction so they could market LLM courses.
Gradual: The last type is when a model degrades slowly over time. The most straightforward occurrence is a data distribution that shifts gradually. For example, consider population drift, which could result in changing disease prevalence rates, or economic conditions that shift as market dynamics evolve.
How to Detect Model Drift
It is common to have both data drift and concept drift occur at the same time, so detection methods won’t differ too much. That being said, we can still break up detection methods into a few types, namely statistical methods, performance metrics, and rule-based checks.
Statistical methods: Firstly, you can analyze the distribution of your most recent data by looking at summary statistics (mean, median, standard deviation, etc.) and judging, with domain knowledge, how it differs from previous distributions.
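For instance, here's a minimal sketch of such a comparison using pandas; the column name and values are purely illustrative:

```python
# A minimal sketch comparing summary statistics between a reference
# window and a recent window with pandas; data here is illustrative.
import pandas as pd

reference = pd.DataFrame({"purchase_amount": [12.0, 15.5, 9.9, 14.2, 11.8]})
recent = pd.DataFrame({"purchase_amount": [22.1, 25.4, 19.8, 24.0, 21.7]})

# Side-by-side summary stats; large gaps in mean/std are a first hint of drift.
summary = pd.concat(
    {"reference": reference.describe(), "recent": recent.describe()}, axis=1
)
print(summary)
```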
Summary statistics like these can be susceptible to noise, however, and more robust techniques like statistical tests may serve you better. Specifically, tests like the Kolmogorov-Smirnov test or the Chi-square test can determine whether the difference between distributions is statistically significant: the null hypothesis is that both samples come from the same distribution, and a small p-value suggests the distribution has shifted and model drift may be underway.
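As a minimal sketch, here's how a two-sample Kolmogorov-Smirnov check might look with scipy; the synthetic data and the 0.05 significance level are illustrative assumptions:

```python
# A minimal sketch of a two-sample drift check using scipy.
# `reference` and `current` stand in for the same feature sampled at
# training time and from recent production data, respectively.
import numpy as np
from scipy import stats

rng = np.random.default_rng(42)
reference = rng.normal(loc=0.0, scale=1.0, size=5000)  # training-time sample
current = rng.normal(loc=0.3, scale=1.0, size=5000)    # recent sample with a shifted mean

# Null hypothesis: both samples come from the same distribution.
statistic, p_value = stats.ks_2samp(reference, current)

ALPHA = 0.05  # significance level; tune to your tolerance for false alarms
if p_value < ALPHA:
    print(f"Possible drift detected (KS statistic={statistic:.3f}, p={p_value:.4f})")
else:
    print(f"No significant drift (KS statistic={statistic:.3f}, p={p_value:.4f})")
```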
Distance Metrics: There are a few distance metrics that can help as well. One commonly used metric is the Wasserstein distance (also known as the Earth Mover’s Distance), which quantifies the “cost” of transforming one data distribution into another. This naturally fits well for detecting and measuring data drift, as it captures both the magnitude and structure of differences between distributions. Some other noteworthy metrics are Kullback-Leibler (KL) Divergence and Jensen-Shannon Divergence.
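Here's a minimal sketch of computing these distance scores with scipy; the samples and the bin count used for the Jensen-Shannon calculation are illustrative assumptions:

```python
# A minimal sketch of distance-based drift scores with scipy.
import numpy as np
from scipy.stats import wasserstein_distance
from scipy.spatial.distance import jensenshannon

rng = np.random.default_rng(0)
reference = rng.normal(0.0, 1.0, size=5000)
current = rng.normal(0.5, 1.2, size=5000)

# Wasserstein distance works directly on the raw samples.
w_dist = wasserstein_distance(reference, current)

# Jensen-Shannon divergence needs discrete probability distributions,
# so we histogram both samples over shared bins first.
bins = np.histogram_bin_edges(np.concatenate([reference, current]), bins=50)
p, _ = np.histogram(reference, bins=bins, density=True)
q, _ = np.histogram(current, bins=bins, density=True)
js_div = jensenshannon(p, q) ** 2  # jensenshannon returns the square root of the divergence

print(f"Wasserstein distance: {w_dist:.3f}")
print(f"Jensen-Shannon divergence: {js_div:.3f}")
```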
Rule-Based Checks: Finally, simple rules based on domain knowledge or common sense are viable strategies as well. Take a spam detection algorithm, for example: if normally 10% of emails are flagged as spam, but that rate has gradually dropped to 1%, then drift has probably occurred. Alternatively, a rule could be based on an evaluation metric – if the F1 score drops by a certain threshold, you should investigate whether model drift occurred.
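A minimal sketch of the spam-rate rule might look like this; the baseline rate, tolerance, and the check_spam_rate helper are hypothetical values you'd set from domain knowledge:

```python
# A minimal sketch of a rule-based drift check for the hypothetical spam
# example above. `recent_predictions` holds the model's labels over some
# recent window (1 = spam, 0 = not spam).
BASELINE_SPAM_RATE = 0.10  # ~10% of emails historically flagged as spam
TOLERANCE = 0.05           # alert if the rate moves more than 5 points

def check_spam_rate(recent_predictions: list[int]) -> None:
    rate = sum(recent_predictions) / len(recent_predictions)
    if abs(rate - BASELINE_SPAM_RATE) > TOLERANCE:
        print(f"Alert: spam rate {rate:.1%} deviates from baseline {BASELINE_SPAM_RATE:.0%}")
    else:
        print(f"Spam rate {rate:.1%} within expected range")

check_spam_rate([1] * 10 + [0] * 990)  # 1% spam rate -> triggers the alert
```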
Resolving Model Drift
When it comes to fixing model drift, having systems in place beforehand helps the most. The best practice here is to track versions of your model. Tracking information, or metadata, about your input data is also valuable (and sometimes overlooked). With both data and model versioning in place, you can easily investigate potential issues once model drift is detected.
Now, how exactly you implement this will always depend on the problem and context, but there are many tools to assist you (like MLflow). These tools range from simple monitoring scripts to sophisticated automated systems, but once you have them in place and have detected model drift, typically the next step is to retrain your model. Retraining involves updating the model using new data to improve its accuracy and adapt to changes in the most recent data distribution.
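As a minimal sketch of what that tracking might look like with MLflow (assuming a scikit-learn model; the run name, logged parameters, and data-hashing scheme are illustrative choices, not a prescribed setup):

```python
# A minimal sketch of model and data versioning with MLflow.
import hashlib
import mlflow
import mlflow.sklearn
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

X, y = make_classification(n_samples=1000, random_state=0)
model = LogisticRegression().fit(X, y)

with mlflow.start_run(run_name="drift-baseline"):
    # Log metadata about the training data so drift investigations can
    # compare what the model was trained on against what it now sees.
    mlflow.log_param("n_rows", X.shape[0])
    mlflow.log_param("n_features", X.shape[1])
    mlflow.log_param("data_hash", hashlib.md5(X.tobytes()).hexdigest())

    mlflow.log_metric("train_accuracy", model.score(X, y))
    mlflow.sklearn.log_model(model, "model")  # versioned model artifact
```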
Depending on the frequency of model drift and the availability of quality data, integrating with CI/CD pipelines and cron jobs can help automate retraining on a consistent or periodic basis. This way, you can trigger events once model drift is detected to automatically update the model (provided the newer version is better). It is crucial to consider the costs of doing this, though, in terms of both compute and training time – there is no free lunch.
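Here's a minimal sketch of that trigger logic, meant to run on a schedule; detect_drift, load_data, and evaluate are hypothetical helpers standing in for your own monitoring, data-loading, and evaluation code:

```python
# A minimal sketch of an automated retraining trigger (e.g., run via
# cron or a CI/CD job). Assumes scikit-learn-style estimators.
from sklearn.base import clone

def retrain_if_drifted(current_model, detect_drift, load_data, evaluate):
    if not detect_drift():
        return current_model  # nothing to do this cycle

    X_new, y_new = load_data()
    candidate = clone(current_model).fit(X_new, y_new)

    # Promote the retrained model only if it actually beats the old one
    # on a held-out evaluation set.
    if evaluate(candidate) > evaluate(current_model):
        return candidate
    return current_model
```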
Aside from simply retraining the model on the most recent data, some other helpful techniques include data augmentation (to make the model more robust to noise), revisiting the feature engineering process (perhaps there is a feature that correlates with the drift), or implementing ensemble methods (stacking multiple models can improve generalizability); a sketch of the stacking idea follows below.
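As a minimal sketch of that ensemble idea with scikit-learn (the base models and synthetic data are illustrative only):

```python
# A minimal sketch of stacking multiple models with scikit-learn.
from sklearn.ensemble import StackingClassifier, RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=1000, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Stacking combines diverse base learners, which can generalize better
# under a shifting data distribution than any single model alone.
stack = StackingClassifier(
    estimators=[
        ("rf", RandomForestClassifier(random_state=0)),
        ("lr", LogisticRegression(max_iter=1000)),
    ],
    final_estimator=LogisticRegression(),
)
stack.fit(X_train, y_train)
print(f"Held-out accuracy: {stack.score(X_test, y_test):.3f}")
```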
Conclusion
To wrap up: model drift is when a model's performance degrades over time, and understanding it is important for handling the situation before a model is rendered useless (and, more importantly, loses its business value). This post only scratches the surface of model drift and MLOps, though, and the only real way to learn is to see it happen in real time. So, I encourage you to build a machine learning app – put a model in production with a frontend, get some users, and see how things develop over time. If something goes wrong, hopefully the information above will help!