Understanding Softmax


When running an inference server, you can choose settings like temperature, top-p, and top-k. To understand these values, we really just need an understanding of the softmax activation function. I couldn’t really find one single website that made the reasoning behind softmax click very easily, so I decided to make a blog post about it. Maybe your brain is wired like mine and this will help 🙂

Okay, so let’s start on the ground level. What is softmax, and why do we need it? My “what” explanation is: softmax is a normalization function for a bunch of numbers such that all the normalized values are probabilities, and the function is biased, tending to exaggerate the probability of the winners.

Why do we need softmax? As far as my understanding goes, it’s used as the very last step: it takes the final set of un-normalized weights (the logits) for the next output token and turns them into probabilities used to pick a token with a certain amount of randomness (controlled by the temperature). If you don’t want any randomness and just want the top token, you can take the index of the largest logit (argmax) and you don’t need softmax at all. Thus, softmax is responsible for generating interesting, but still high-quality, responses by introducing weighted randomness between high-confidence tokens.

It’s actually unclear to me whether softmax is used as an activation function anywhere other than the final output. Instead, you could use e.g. ReLU or other simpler activation functions (especially for performance reasons), but I can’t easily figure out what various inference servers use. If I’m wrong I’d love to hear it!

Normalization

Given an array of real numbers, what are some normalization strategies? First off, what do we specifically mean by normalization in this context? We could mean simply restricting the values to a certain range (in this case 0-1). We could also mean transforming them into probabilities. If we only cared about the former, we would just shift all the values to be 0-based (think about your vector math: it’s the same vector whether you add or subtract a constant, you’re just shifting the baseline). Then we can divide every number by the spread (max - min) to get a scaled-down version that fits between 0 and 1.

However, that doesn’t turn the numbers into probabilities. To do that, we have to remember how histograms work. I’ll spare you the tangent too far down into basic math, but if you remember that the mean of a series is sum(X1..Xn) / n, then it follows logically that the probability weight of Xi is Xi / sum(series).
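To make the two meanings concrete, here’s a minimal sketch in Python (the function names are mine, just for illustration):

```python
def minmax_scale(xs):
    """Shift and scale values so they land in the range 0-1."""
    lo, hi = min(xs), max(xs)
    return [(x - lo) / (hi - lo) for x in xs]

def to_probabilities(xs):
    """Histogram-style normalization: divide each value by the total.
    Assumes the values are non-negative."""
    total = sum(xs)
    return [x / total for x in xs]

print(minmax_scale([0, 72, 11]))    # [0.0, 1.0, 0.1527...]
print(to_probabilities([1, 2, 3]))  # [0.1666..., 0.3333..., 0.5]
```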

Weighted normalization

So far, with our histogram-based normalization we’ve created a system that can spit out the next token with a probability directly proportional to what the transformer generated. And that’s a fine starting point! The only difference between that and softmax is that we transform all of the logits by an exponential, e.g. instead of [0, 72, 11] it becomes [e^0, e^72, e^11] before applying the algorithm from earlier. So the final formula becomes softmax(xi) = e^xi / (e^x1 + e^x2 + ... + e^xn)
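A naive, direct translation of that formula into Python might look like this (no overflow protection yet, which is exactly the problem discussed next):

```python
import math

def softmax_naive(logits):
    """Exponentiate every logit, then normalize by the sum of the exponentials."""
    exps = [math.exp(x) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

print(softmax_naive([1, 2, 3]))  # ≈ [0.090, 0.245, 0.665]
```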

Let’s first focus on the effects of doing this transformation, and then we can come back to why we are doing this at all in the first place, rather than just using the unweighted normalization algorithm.

The first effect… (at least the first one my brain went to)… is overflow. Handling numeric edge cases is always something you should be thinking about right up front! But yeah, why doesn’t this overflow immediately? Not only are you taking the exponent of an unbounded number, you’re summing all of those exponents together!

Well, it turns out that there’s a math party trick for calculating this without overflow. We know the output is going to be between 0 and 1, since the numerator and denominator scale together and cancel each other out. The math party trick isn’t even that exciting 🙂 It just takes advantage of the fact that e^(x-m) = e^x / e^m, so we can subtract a term (in this case the largest logit) from every value to make them all zero or negative, and then their exponentials stay small. That’s my one-line summary, but it’s described much better in the Numerical Computation chapter of the Deep Learning Book, which I found via a few semi-confusing stackoverflow posts.
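Here’s a sketch of that trick in code: the shift by the maximum cancels in the numerator and denominator, so the result is identical, but every exponent is now zero or negative.

```python
import math

def softmax_stable(logits):
    """Same result as softmax_naive, but shifted by the max logit so exp() can't overflow."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]  # every exponent is <= 0
    total = sum(exps)
    return [e / total for e in exps]

print(softmax_stable([0, 720, 11]))  # fine; softmax_naive would raise OverflowError here
```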

Next, the effect of softmax is to shift the probability so that a few values carry more of the weight and the other values carry less. For example, consider an array of three logits, [1, 2, 3]. With an unweighted distribution, the value 3 would have a 50% probability (3/6), but with softmax it becomes ≈ 0.665.

Figure: probabilities for the logits [1, 2, 3] under softmax versus simple normalization. The softmax bars have probabilities .09, .245, and .665; the simple normalization bars have .166, .33, and .5.

This is actually really intuitive, because the larger the number, the more the exponent amplifies it. For a series of [1, 2, 3], e^n = [2.7, 7.4, 20.1]. So the largest numbers are always going to be overweighted in this scheme. As an aside, the base you choose doesn’t really matter, see this post for details.

The final effect that softmax has is converting everything into a probability distribution. We already know that each term is between 0 and 1; for them to form a probability distribution, all of the terms also need to sum to 1. We can prove this with just a few short lines of math. Recall that the formula is softmax(xi) = e^xi / Z, where we use Z to represent the denominator (the sum of the exponentials, e^x1 + e^x2 + ... + e^xn). The sum of all the softmax terms is then e^x1/Z + e^x2/Z + ... + e^xn/Z. We can pull 1/Z out of the summation, and that becomes 1/Z * (e^x1 + e^x2 + ... + e^xn). But we defined Z to be exactly that sum of exponentials, so the formula reduces down to 1/Z * Z = 1.
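The same derivation in compact notation, with Z standing for the sum of the exponentials:

```latex
\mathrm{softmax}(x_i) = \frac{e^{x_i}}{Z}, \qquad Z = \sum_{j=1}^{n} e^{x_j}
\quad\Rightarrow\quad
\sum_{i=1}^{n} \mathrm{softmax}(x_i)
  = \sum_{i=1}^{n} \frac{e^{x_i}}{Z}
  = \frac{1}{Z}\sum_{i=1}^{n} e^{x_i}
  = \frac{Z}{Z} = 1
```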

Adjusting the weights with Temperature

We just saw that changing the base from e to another value doesn’t really affect the output here. However, if we want to change the shape of the probability distribution, we can scale the logits before exponentiating them. Given a constant temperature value T, the modified softmax formula is softmax(xi) = e^(xi/T) / (e^(x1/T) + e^(x2/T) + … + e^(xn/T))
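A sketch of this in Python (the function name is mine), with a small loop that previews the convergence behavior shown in the next figure:

```python
import math

def softmax_with_temperature(logits, temperature):
    """Divide every logit by T before exponentiating; uses the max-subtraction trick."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

for t in (0.1, 1.0, 10.0, 100.0):
    print(t, softmax_with_temperature([1, 2], t))
# T=0.1   -> ≈ [0.00005, 0.99995]  (almost all weight on the larger logit)
# T=1.0   -> ≈ [0.269, 0.731]      (plain softmax)
# T=100.0 -> ≈ [0.4975, 0.5025]    (slowly converging towards 50/50)
```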

We can see the effect of temperature for two logits:

Figure: probability as a function of temperature for two logits (with values 1 and 2). Near zero temperature the larger logit gets 100% of the probability; as temperature increases, the two curves very slowly converge towards 50%.

The dashed lines mark the unweighted baseline: a 66% chance for the logit of 2 and 33% for the logit of 1. However, if the temperature is ~0, the larger logit gets essentially 100%. As the temperature increases, the distribution slowly converges towards 50/50.

We can see a similar thing happen for three logits (with scores 1, 2, and 3):

The same pattern applies: as the temperature becomes small, the top value takes away more and more of the probability. You can see that the second value retains a bit more probability, and the smallest item drops away most quickly.

If the temperature is 0, then you get a division-by-zero error. So in practice, any code setting the temperature will adjust 0 to mean something like 0.001.
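A hypothetical guard might look like this (the 0.001 epsilon just mirrors the example above; real servers may instead special-case a temperature of 0 as plain argmax):

```python
def effective_temperature(t, eps=0.001):
    """Clamp a user-supplied temperature of 0 to a tiny positive value
    so that e^(x/T) stays defined."""
    return max(t, eps)

print(effective_temperature(0))    # 0.001
print(effective_temperature(0.7))  # 0.7
```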

But why?

Okay, so now that we know how softmax works, why is it used at all? From my understanding, the primary reason softmax is used is for model training, not necessarily inference. For training, you need a gradient for back-propagation (which softmax provides), and additionally, softmax pairs well with minimizing cross-entropy loss (which is used to ensure the model is being trained toward the right answer). Some snippets from around the web:

The Softmax classifier gets its name from the softmax function, which is used to squash the raw class scores into normalized positive values that sum to one, so that the cross-entropy loss can be applied. In particular, note that technically it doesn’t make sense to talk about the “softmax loss”, since softmax is just the squashing function, but it is a relatively commonly used shorthand.
https://cs231n.github.io/linear-classify/#softmax

Cutting off z with P(Y=1|z)=max{0,min{1,z}} yields a zero gradient for z outside of [0,1]
We need a strong gradient whenever the model’s prediction is wrong, because we solve logistic regression with gradient descent. For logistic regression, there is no closed form solution.
The logistic function has the nice property of asymptoting a constant gradient when the model’s prediction is wrong, given that we use Maximum Likelihood Estimation to fit the model
https://stats.stackexchange.com/questions/162988/why-sigmoid-function-instead-of-anything-else/318209#318209

Okay, so primarily, during training, softmax combined with cross-entropy loss provides a strong gradient to help learn the correct weights for the model.
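As a quick sketch of that “strong gradient”: for a one-hot target, the gradient of cross-entropy loss with respect to the logits works out to softmax(z) minus the target vector, so a confidently wrong prediction produces a large correction. The helper below is just an illustration of that identity:

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy_grad(logits, target_index):
    """d(loss)/d(logit_i) = softmax(logits)_i - y_i for a one-hot target y."""
    return [p - (1.0 if i == target_index else 0.0)
            for i, p in enumerate(softmax(logits))]

print(cross_entropy_grad([5.0, 0.0, 0.0], target_index=1))
# ≈ [0.987, -0.993, 0.007]: a big push away from the confidently wrong answer
```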

However, just because softmax is used for training doesn’t mean it _has_ to be used for inference. In fact, just taking the highest value (argmax) is used in many cases:

If we are using a machine learning model for inference, rather than training it, we might want an integer output from the system representing a hard decision that we will take with the model output, such as to treat a tumor, authenticate a user, or assign a document to a topic. The argmax values are easier to work with in this sense and can be used to build a confusion matrix and calculate the precision and recall of a classifier. https://deepai.org/machine-learning-glossary-and-terms/softmax-layer

Also remember that when setting the temperature to 0, that’s essentially the same as argmax (and we can just take the largest value instead of computing softmax)

The other reason I think softmax is used is the need to smartly add randomness to the generated output. If the model always returns the top token 100% of the time, it is very stilted and only produces one certain type of output. Choosing interesting paths, some amount of the time, allows the model to go down creative or expressive paths without sacrificing accuracy too much. Softmax (combined with temperature) provides a nice algorithm to turn all of the possible choices into probabilities. Since we have to do this work anyways, my best guess is there’s value in using the same softmax algorithm that was used in training. Otherwise, even if the top token is the same in both cases, the weights learned for the second-best through last tokens would be interpreted differently at inference than they were during training. I couldn’t find any references about this though, so please do let me know if this assumption is correct!

Performance considerations

Softmax, and using it to pick tokens, is expensive. Softmax involves taking an exponential for every term and then dividing by their sum. Exponentials are not cheap to calculate compared to, e.g., multiplication and division; they’re computed via a Taylor series expansion or other optimizations. This website is 15 years old, but it suggests that pow operations are 100x slower. I don’t know the actual numbers, but I’m sure it’s slower than a simple mul instruction.

Second, let’s think about how softmax is used. The numbers being fed into softmax correspond to the output tokens, so there is one logit for every single possible output token. Tokenization methods differ across models, but we’re talking about tens of thousands of tokens.

So we have 10,000+ expensive exponentials, and we have to do that for every single output token that’s predicted. Computers are very fast, but the rest of inference is designed to scale as quickly as possible using basic matrix math that is easy and cheap to parallelize.

Additionally, we need to sort these values to use them as probabilities. Let’s think about how this works. We compute softmax, and then we want to pick a token. We ask the computer to choose rand(0, 1), and then… how do we know which token that number corresponds to? Well, if we sorted the softmax output so it looked like e.g. [0.82, 0.1, 0.07, 0.01], then we would know that 0 to 0.82 is the first token, 0.82 to 0.92 is the second token, and so on. So now we also have an O(n log n) cost to sort these tokens as well.
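A sketch of that lookup, assuming the probabilities have already been sorted in descending order:

```python
import random

def sample_from_sorted(probs):
    """Pick an index by walking the cumulative distribution until rand(0, 1) is passed."""
    r = random.random()
    cumulative = 0.0
    for token_index, p in enumerate(probs):
        cumulative += p
        if r < cumulative:
            return token_index
    return len(probs) - 1  # guard against floating-point rounding

print(sample_from_sorted([0.82, 0.10, 0.07, 0.01]))  # returns 0 about 82% of the time
```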

There are a few ways we can improve the performance, and this is where top-k comes in.

Top-k

We have to sort the numbers anyway, and because softmax preserves order, the sort order of the logits before applying softmax is the same as the sort order of the probabilities afterwards. So, if we do the sorting first, we already know which values are extremely improbable, and we can exclude them entirely from the softmax calculation. That’s what the top-k value does. Even if your tokenization algorithm has one million possible output tokens, top-k says “give me the highest 1000 tokens before running softmax”. You still pay the O(n log n) cost to sort, but not the exponential cost.
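A sketch of top-k under those assumptions (a real implementation would keep the token ids alongside the values; here I only keep the logit values for simplicity):

```python
import math

def top_k_probs(logits, k):
    """Sort once, keep the k largest logits, and only exponentiate those."""
    top = sorted(logits, reverse=True)[:k]  # the O(n log n) sort we pay anyway
    m = top[0]
    exps = [math.exp(x - m) for x in top]   # only k exponentials instead of n
    total = sum(exps)
    return [e / total for e in exps]

print(top_k_probs([1, 2, 3, -5, -9], k=3))  # ≈ [0.665, 0.245, 0.090]
```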

Since you are removing tokens, this does have the side-effect of giving even more probability to the largest tokens (since remember those tokens are the greediest). But as long as your top-k is reasonably large (and captures most of the significant total probability) then this isn’t an issue.

Other related parameters

Top-k is the only parameter that affects performance, since it is applied before the softmax calculation. The other parameters are applied after softmax, to decide which probabilities should still be considered. There’s an in-depth reddit post with diagrams that I suggest reading if you want to learn more. A rough code sketch of both filters follows their descriptions below.

Top-p: Keep taking probabilities until you reach a required cumulative probability, then drop the remaining tokens. You could represent this in code by changing rand(0,1) to rand(0,top_p) and then making sure the lower probabilities are at the upper end of your range

Min-p: (directly quoted from the reddit post): What Min P is doing is simple: we are setting a minimum value that a token must reach to be considered at all. The value changes depending on how confident the highest probability token is. So if your Min P is set to 0.1, that means it will only allow for tokens that are at least 1/10th as probable as the best possible option. If it’s set to 0.05, then it will allow tokens at least 1/20th as probable as the top token, and so on…
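Here are rough sketches of both filters, applied to probabilities that are already sorted in descending order (the names and signatures are my own; after filtering you would renormalize and sample as before):

```python
def top_p_filter(sorted_probs, top_p):
    """Keep tokens until their cumulative probability reaches top_p."""
    kept, cumulative = [], 0.0
    for p in sorted_probs:
        kept.append(p)
        cumulative += p
        if cumulative >= top_p:
            break
    return kept

def min_p_filter(sorted_probs, min_p):
    """Keep tokens at least min_p times as probable as the top token."""
    threshold = min_p * sorted_probs[0]
    return [p for p in sorted_probs if p >= threshold]

probs = [0.5, 0.25, 0.15, 0.07, 0.03]
print(top_p_filter(probs, 0.9))   # [0.5, 0.25, 0.15]
print(min_p_filter(probs, 0.1))   # [0.5, 0.25, 0.15, 0.07]
```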

Summary

Softmax is a relatively expensive activation function that is typically used as the last step before picking each output token. It has the nice feature that the resulting values are between 0 and 1 and can be interpreted as probabilities. It skews the results to prefer the top choices, and this skewing can be controlled by the temperature. Beyond inference, softmax is chosen because it pairs well with cross-entropy loss during model training.