I'm doing binary classification and I want to manually review cases where the model was either wrong, or right but with low confidence. The most confident incorrect predictions should appear first, followed by less confident incorrect predictions, followed by correct predictions sorted from least to most confident. I want to check these examples by hand to see whether there's a pattern in the kinds of examples where the model needs help. My real project involves images created using Stable Diffusion, so I can create more targeted training examples if I see patterns.
Here's a simplified example of my data.

    import polars as pl

    pl.DataFrame({
        "name": ["Alice", "Bob", "Caroline", "Dutch", "Emily", "Frank", "Gerald", "Henry", "Isabelle", "Jack"],
        "truth": [1, 0, 1, 0, 1, 0, 0, 1, 1, 0],
        "prediction": [1, 1, 1, 0, 0, 1, 0, 1, 1, 0],
        "confidence": [0.343474, 0.298461, 0.420634, 0.125515, 0.772971, 0.646964, 0.833705, 0.837181, 0.790773, 0.144983],
    }).with_columns(
        (1*(pl.col("truth") == pl.col("prediction"))).alias("correct_prediction")
    )

Emily should appear first because she's the highest-confidence incorrect classification. After the other wrong predictions, Dutch should appear next because he has the lowest-confidence correct guess.

| name | truth | prediction | confidence | correct_prediction |
|---|---|---|---|---|
| Emily | 1 | 0 | 0.772971 | 0 |
| Frank | 0 | 1 | 0.646964 | 0 |
| Bob | 0 | 1 | 0.298461 | 0 |
| Dutch | 0 | 0 | 0.125515 | 1 |
| Jack | 0 | 0 | 0.144983 | 1 |
| Alice | 1 | 1 | 0.343474 | 1 |
| Caroline | 1 | 1 | 0.420634 | 1 |
| Isabelle | 1 | 1 | 0.790773 | 1 |
| Gerald | 0 | 0 | 0.833705 | 1 |
| Henry | 1 | 1 | 0.837181 | 1 |
I'm moving from Pandas to Polars and can't figure out how to perform this sort. According to the documentation, you can use an expression with sort(), but it's not clear how I can include an if statement in the expression. I'd also be open to calculating a new sort column and then performing a simple sort() on that, if there's some formula that would do what I want.
I know I could split the DataFrame into correct_predictions and incorrect_predictions, use different sorting logic on each, and then concat() them back together. I'm looking for something more elegant and less messy.