Add JSD kernel #264
Conversation
@yundai424 btw, should I change jsd to js_div?
jsd is fine~
Added jsd benchmark script
GPU CI failing on irrelevant tests
Pushed an implementation without inplace operations on the input. See #262 (comment). We can decide which one to use in the future.
Thanks for the efforts! Looks good to me in general. I'm assuming the log-scale input is for taking log_softmax as input in common cases, which is more numerically stable? (One minor useful feature we could add, possibly in a future PR, is ignore-index support. We could provide an extra label tensor input for it, but that might be too specific, since general JSD does not take hard labels as input.)
Yes, and torch.KLDivLoss also takes input in log-space. I think it would be better to offer similar arguments for users.
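For illustration, a minimal sketch of the analogous torch.nn.KLDivLoss convention: with log_target=True both arguments are log-probabilities, which mirrors the convention this kernel adopts for $X = \log Q$ and $Y = \log P$ (tensor names below are hypothetical):

```python
import torch
import torch.nn.functional as F

# Both arguments are log-probabilities when log_target=True,
# mirroring the JSD kernel's convention (X = log Q, Y = log P).
log_q = F.log_softmax(torch.randn(4, 8), dim=-1)  # input  (X = log Q)
log_p = F.log_softmax(torch.randn(4, 8), dim=-1)  # target (Y = log P)

kl = torch.nn.KLDivLoss(reduction="batchmean", log_target=True)
loss = kl(log_q, log_p)  # KL(P || Q) computed from log-space inputs
```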
LGTM, thanks for the contribution! @qingquansong could you help create two follow-up issues for 1) adding ignore-index support for divergence losses, and 2) adding general JSD (w/ beta) support? Thanks a lot!!
LGTM! Let me know when the CI tests are fixed and I can give a quick shipping stamp. Thanks!
Summary
Resolve #252
Details
JSD
We expect the input $X$ and target $Y$ to be distributions in log-space, i.e., $X = \log Q$ and $Y = \log P$.

The Jensen-Shannon Divergence between two distributions $P$ and $Q$ is defined as:

$$JSD(P\ \Vert\ Q) = \frac{1}{2}\, KL(P\ \Vert\ M) + \frac{1}{2}\, KL(Q\ \Vert\ M)$$

where $M = \frac{1}{2}(P + Q)$ is the average distribution and $KL$ is the Kullback-Leibler divergence.

Given that $X = \log Q$ and $Y = \log P$, we can simplify the JSD expression to:

$$JSD = \frac{1}{2}\left(KL\left(e^Y\ \Vert\ M\right) + KL\left(e^X\ \Vert\ M\right)\right), \quad M = \frac{1}{2}\left(e^X + e^Y\right)$$
We define the point-wise JSD as:

$$JSD_i = \frac{1}{2}\left(e^{Y_i}\left(Y_i - \log M_i\right) + e^{X_i}\left(X_i - \log M_i\right)\right)$$
With the point-wise JSD, it's easier to implement JSD with different reduction methods in the future.
The only downside is that it creates a torch.float32 tensor with the same shape as the input.
The current implementation is hardcoded to batchmean, which matches the original JSD definition. A reference sketch is shown below.
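For reference, a minimal eager-PyTorch sketch of the point-wise JSD with batchmean reduction described above; this is an illustrative reimplementation, not the Triton kernel, and the function name is hypothetical:

```python
import torch

def jsd_reference(log_q: torch.Tensor, log_p: torch.Tensor) -> torch.Tensor:
    """Point-wise JSD with batchmean reduction.

    log_q: input X = log Q, shape (B, V)
    log_p: target Y = log P, shape (B, V)
    """
    p, q = log_p.exp(), log_q.exp()
    m = 0.5 * (p + q)            # average distribution M
    log_m = m.log()
    # point-wise term: JSD_i = 0.5 * (P_i (Y_i - log M_i) + Q_i (X_i - log M_i))
    pointwise = 0.5 * (p * (log_p - log_m) + q * (log_q - log_m))
    # batchmean: sum over all elements, then divide by the batch size
    return pointwise.sum() / log_q.shape[0]
```

Keeping the point-wise term as a separate float32 tensor is what makes swapping in other reductions (sum, mean, none) a one-line change on the final reduction step.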
Gradients
Given:

$$JSD = \frac{1}{2}\left(KL(P\ \Vert\ M) + KL(Q\ \Vert\ M)\right)$$

where $Q = e^X$, $P = e^Y$, and $M = \frac{1}{2}\left(e^X + e^Y\right)$.
Gradients of $KL(P\ \Vert\ M)$ with respect to $X_i$:

$$\frac{\partial\, KL(P\ \Vert\ M)}{\partial X_i} = -\frac{P_i Q_i}{2 M_i}$$
Gradients of $KL(Q\ \Vert\ M)$ with respect to $X_i$:

$$\frac{\partial\, KL(Q\ \Vert\ M)}{\partial X_i} = Q_i\left(X_i + 1 - \log M_i\right) - \frac{Q_i^2}{2 M_i}$$
Final gradients of JSD. Combining the results from the two KL divergence terms:

$$\frac{\partial\, JSD}{\partial X_i} = \frac{1}{2}\left(-\frac{P_i Q_i}{2 M_i} + Q_i\left(X_i + 1 - \log M_i\right) - \frac{Q_i^2}{2 M_i}\right)$$

Since $P_i + Q_i = 2 M_i$, this simplifies to:

$$\frac{\partial\, JSD}{\partial X_i} = \frac{1}{2}\, Q_i\left(X_i - \log M_i\right)$$
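As a sanity check on the simplified gradient, here is a small sketch comparing the closed form against torch.autograd for the sum-reduced JSD (illustrative only, not part of the kernel):

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 8).log_softmax(dim=-1).detach().requires_grad_(True)  # X = log Q
y = torch.randn(4, 8).log_softmax(dim=-1)                                # Y = log P

p, q = y.exp(), x.exp()
m = 0.5 * (p + q)  # average distribution M

# Forward: JSD = 0.5 * (KL(P || M) + KL(Q || M)), summed over all elements
jsd = 0.5 * (p * (y - m.log()) + q * (x - m.log())).sum()
jsd.backward()

# Closed form derived above: dJSD/dX_i = 0.5 * Q_i * (X_i - log M_i)
analytic = (0.5 * q * (x - m.log())).detach()
assert torch.allclose(x.grad, analytic, atol=1e-5)
```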
We store the gradients at X_ptr in the forward pass to save memory, then retrieve them through ctx in the backward function, as cross_entropy does (inplace).
Note: inplace operations on inputs might cause an issue with gradient computation.
Testing Done
With inplace (storing gradients to inputs):
- reduces memory usage by 61.54%
- increases speed by 53.64%

Without inplace:
- reduces memory usage by 53%
- increases speed by 61%
- `make test` to ensure correctness
- `make checkstyle` to ensure code style
- `make test-convergence` to ensure convergence