-
Notifications
You must be signed in to change notification settings - Fork 511
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MRG] Gaussian Mixture OT #649
Conversation
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## master #649 +/- ##
==========================================
+ Coverage 96.71% 96.78% +0.07%
==========================================
Files 86 88 +2
Lines 17148 17502 +354
==========================================
+ Hits 16584 16940 +356
+ Misses 564 562 -2 |
… to reflect PR changes
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A few comments, but this is very good work.
ot/gmm.py
Outdated
Cs12 = nx.sqrtm(C_s[i]) | ||
Cs12inv = nx.inv(Cs12) | ||
|
||
for j in range(k_t): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
challenge 2 (optional) :here the second loop should be doable with einsum and batch sqrtm too
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for both maps, I ended up doing a loop over all nonzero indices in the plan, I don't know if this is better than a smart einsum
.
# i and j, b[i, j] is the translation part | ||
rng = np.random.RandomState(seed) | ||
|
||
A = nx.zeros((k_s, k_t, d, d)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is a BIIG tensor!, again mayb can be done with einsum
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looping over nonzero entries in the plan is probably faster for this function, since there are at most k_s + k_t - 1
nonzero entries. I'll look into that
ot/plot.py
Outdated
|
||
|
||
def plot1D_mat(a, b, M, title=''): | ||
r""" Plot matrix :math:`\mathbf{M}` with the source and target 1D distribution | ||
def plot1D_mat(a, b, M, title='', a_label='Source distribution', |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks for the big and necessary update but keep in mind that yteh function must keep a very similar behavior by default than teh previous version.
That is the color and source/target should be the same as previous version . It is OK to have parameter that change this default behavior and you can use the values you want in your examples but this update should not change too much the other examples that use the function (except for the fill of the densities it is much better than the previous plot).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I made the default behaviour basically identical to the old ones (just quality upgrades, everything is in the same place), and the 'xy' version can be used with plot_style
which defaults to yx. Colours also default to blue and red :)
…nsity backend + gmm_density does the meshgrid stuff
TODO
nx.det
nx.det
ot.plot.plot1D_mat
updateot.plot
nx.sqrtm
ot.gmm
additions (hopefully)Types of changes
ot.gmm
for Gaussian Mixture Model OT, with a GMM plan example and a GMM flow exampleot.gaussian
ot.utils.proj_SDP
for a broadcastable and backend-compatible projection of a symmetric matrix onto the SDP conenx.det
nx.sqrtm
ot.plot.proj1D_mat
with the same API with visual improvementsMotivation and context / Related issue
new features implementing (GMM OT)[https://arxiv.org/pdf/1907.05254] +
ot.plot1D_map
improvementHow has this been tested (if it applies)
pystest
full coverage + two example in the galleryPR checklist