Meta Flow Maps

Microsoft Research · Advanced · Mathematical Foundations

Key Takeaways

Introduces Meta Flow Maps for controlling generative models via one-step posterior sampling

Full Transcript

Hi, thank you for coming. One more week. Today we're very excited to have Peter Bidexip present here in person. His talk is titled "Meta Flow Maps". Peter is a PhD student at Oxford University advised by EYT, and a visiting fellow at Harvard advised by Michael Oberman. So with that, thank you for coming, and take it away, Peter.
Thanks. Today I'll be presenting our work "Meta Flow Maps". This is joint work with my co-first author Adi, our team boss Alvaro, and my advisors Michael and EY. The overall focus of our work is reward alignment via one-step posterior sampling.
I'll start with a brief recap of generative modeling. In transport-based generative modeling we have a prior distribution p_0, say a standard Gaussian, and a data distribution p_1, which is accessible through a dataset of samples. Our goal is to learn to transform samples from p_0 into samples from p_1, that is, noise into data. There have been many different methods for doing this, such as diffusion models and flow matching. I'll adopt the notation of stochastic interpolants: we draw a sample I_0 from the prior and a data sample I_1 from the dataset, and interpolate linearly between them, I_t = (1 - t) I_0 + t I_1. For each t this induces a probability distribution, which I'll denote p_t. Our goal is to construct an ODE, dX_t/dt = b_t(X_t), such that if I initialize it at the noise distribution p_0 and integrate over time, then at every time t the law of X_t matches the law of the interpolant, p_t. In particular, at time one the endpoints X_1 of the ODE trajectories are samples from the data distribution p_1. That's the overall goal.
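For reference, the setup just described can be written compactly as follows. The notation mirrors the talk; taking p_0 to be a standard Gaussian is the example the speaker gives rather than a requirement.

```latex
\begin{aligned}
I_t &= (1-t)\,I_0 + t\,I_1, \qquad I_0 \sim p_0 = \mathcal{N}(0, \mathrm{Id}),\quad I_1 \sim p_1,\quad I_t \sim p_t,\\
\frac{\mathrm{d}X_t}{\mathrm{d}t} &= b_t(X_t), \qquad X_0 \sim p_0, \qquad \text{goal: } \mathrm{Law}(X_t) = p_t \ \text{ for all } t \in [0,1].
\end{aligned}
```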
So how do we actually find the drift b_t for the ODE? We use flow matching. Choose the vector field b_t(x) to be the conditional expectation of the velocity given that the interpolant equals x at time t, b_t(x) = E[I_1 - I_0 | I_t = x], where the conditional velocity I_1 - I_0 is just data minus noise. One can show that this drift satisfies exactly the previous conditions: if I use it, the terminal samples of my ODE are exactly from the data distribution. We learn it by regressing a neural network b̂_t onto the time derivative of the interpolant, and by the properties of L2 regression the minimizer is exactly the expected velocity. That was a quick recap of stochastic interpolants and flow matching: we now have a drift b_t(x) that generates samples from the data distribution p_1.
In practice, using this drift means numerically integrating the ODE, and that's expensive: we have to evaluate the drift at many different steps, each evaluation requires a network call, and that adds up. So recently there's been a big push for new types of models, consistency models and flow maps, which try to bypass explicit step-by-step integration of the ODE and instead learn its solution operator, so that you can jump between different time points of a trajectory. In particular, given an ODE trajectory from time zero to one, I want my flow map X to satisfy that, given the point x_s at time s along that trajectory, it jumps straight to the value at time u in one step: X_{s,u}(x_s) = x_u. This means I can now do few-step or even one-step generative modeling.
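A minimal sketch of the drift, under assumptions that are ours rather than the talk's: a 1D toy problem with prior N(0, 1), Gaussian data N(M, S^2), and the linear interpolant. In this setting the conditional expectation E[I_1 - I_0 | I_t = x] has a closed form, so it stands in for the trained network b̂_t; integrating the ODE with many Euler steps illustrates the cost that flow maps are meant to avoid.

```python
# Hypothetical 1D toy problem (not from the talk): prior p0 = N(0, 1),
# data p1 = N(M, S^2), linear interpolant I_t = (1 - t) I0 + t I1.
M, S = 2.0, 0.5


def drift(t: float, x: float) -> float:
    """Exact flow-matching drift b_t(x) = E[I1 - I0 | I_t = x] for this Gaussian toy.

    (I1 - I0, I_t) are jointly Gaussian, so the conditional mean is affine in x.
    """
    var_t = (1 - t) ** 2 + (t * S) ** 2  # Var(I_t)
    cov = t * S**2 - (1 - t)             # Cov(I1 - I0, I_t)
    return M + cov / var_t * (x - t * M)


def integrate(x0: float, n_steps: int = 1000) -> float:
    """Forward Euler from t = 0 to t = 1; costs n_steps drift evaluations."""
    x, dt = x0, 1.0 / n_steps
    for k in range(n_steps):
        x += dt * drift(k * dt, x)
    return x


# Under the exact ODE, the noise value 0.5 should land close to M + S * 0.5 = 2.25.
print(integrate(0.5))
```

Using 1000 steps here is deliberate: each step is one drift evaluation, which in a real model would be one network call.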
Taking s = 0 and u = 1, the flow map jumps straight from noise to data. There have been a lot of really interesting recent works on different ways of training flow maps and consistency models. Our work is agnostic to the choice of training method: any way you like to train flow maps or consistency models applies within our framework. But I'll give a brief overview for people who aren't familiar with how to train flow maps.
One standard approach is to parameterize the flow map in residual form, X_{s,u}(x) = x + (u - s) v_{s,u}(x): the identity plus an offset by a velocity, where v_{s,u} represents the average velocity of the trajectories between times s and u. We learn a neural network v̂ along with its induced map X̂ such that X̂ equals the true flow map of the ODE defined before. We do this via two losses. The first ensures that the average velocity matches the instantaneous drift in the limit. If v_{s,u}(x) is the average velocity from time s to time u starting at x, then as s approaches u, v_{u,u}(x) should equal the instantaneous drift b_u(x) that we got from flow matching. To make sure that's true, we can regress v̂_{u,u} using the same flow matching loss as on the previous slide, and if we minimize this loss, the average velocity is correct when s equals u. Once we have the correct instantaneous velocity, we want to ensure correctness when s and u differ, so that v_{s,u} really represents the average velocity over the time interval.
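On the same hypothetical Gaussian toy (prior N(0, 1), data N(M, S^2), linear interpolant), the exact flow map is available in closed form, so the residual parameterization can be inverted to recover the average velocity v_{s,u}. The sketch checks the diagonal condition numerically: v_{s,u}(x) approaches the instantaneous drift b_u(x) as s approaches u.

```python
import math

# Hypothetical 1D toy (not from the talk): p0 = N(0, 1), p1 = N(M, S^2),
# I_t = (1 - t) I0 + t I1. Trajectories of the exact ODE are x_t = t*M + sigma(t)*z.
M, S = 2.0, 0.5


def sigma(t: float) -> float:
    """Standard deviation of I_t."""
    return math.sqrt((1 - t) ** 2 + (t * S) ** 2)


def flow_map(s: float, u: float, x: float) -> float:
    """Exact flow map X_{s,u}(x): moves x from time s to time u along its trajectory."""
    return u * M + sigma(u) / sigma(s) * (x - s * M)


def drift(t: float, x: float) -> float:
    """Instantaneous drift b_t(x) of the same ODE."""
    return M + (t * S**2 - (1 - t)) / sigma(t) ** 2 * (x - t * M)


def avg_velocity(s: float, u: float, x: float) -> float:
    """v_{s,u}(x) from the residual form X_{s,u}(x) = x + (u - s) v_{s,u}(x), for s != u."""
    return (flow_map(s, u, x) - x) / (u - s)


# Diagonal condition: v_{s,u}(x) approaches b_u(x) as s -> u.
for gap in (1e-1, 1e-2, 1e-3):
    print(gap, abs(avg_velocity(0.5 - gap, 0.5, 1.0) - drift(0.5, 1.0)))
```

In actual training the direction is reversed: v̂ is the network and X̂ is defined from it. The inversion here is only possible because the toy's exact map is known.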
One equivalent way of enforcing that v_{s,u} is the true average velocity, which amounts to requiring that the learned X is a valid flow map, is the semigroup property: if I start at some value x at time s, jump to time w, and then jump from time w to time u, that should be the same as starting at x at time s and jumping straight to u, i.e. X_{w,u}(X_{s,w}(x)) = X_{s,u}(x). Enforcing the semigroup property together with the diagonal property from the previous slide gives us a valid flow map, and almost all consistency or flow map objectives can be seen as enforcing these two conditions. There have been many different consistency losses, to name a few: mean flows, shortcut models, Eulerian and Lagrangian losses, and terminal velocity matching. All of them essentially try to enforce this condition. For concreteness, I'll write down one loss here, which directly penalizes violations of the rule: I regress the left-hand side onto the right-hand side and penalize their squared difference. If I minimize this over all times s, w, u and all initial query points x, the minimizer makes the difference zero everywhere, and then we have exactly the semigroup property.
That was a rapid-fire overview of consistency models and flow maps. Again, the exact training details aren't too important; you can take your favorite way of training flow maps. What matters is the idea: originally we had a drift b_t that transported us from p_0 to p_1 but required many steps, and now with flow maps we compress these ODE trajectories into one-step maps that jump straight from noise p_0 to data p_1.
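The semigroup loss can be sketched as a Monte Carlo average over random times and query points. This is an illustrative sketch on the same hypothetical Gaussian toy, not the talk's training code: it only evaluates the loss (there is no network or optimizer), and sampling ordered times s <= w <= u with x ~ N(0, 1) is our assumption about the sampling distribution. The deliberately biased map has the right diagonal behavior, so a diagonal loss alone would not catch it, but it violates the semigroup property.

```python
import math
import random

# Hypothetical 1D toy (not from the talk): p0 = N(0, 1), p1 = N(M, S^2),
# linear interpolant; the exact flow map is affine in x.
M, S = 2.0, 0.5


def sigma(t: float) -> float:
    return math.sqrt((1 - t) ** 2 + (t * S) ** 2)


def exact_map(s: float, u: float, x: float) -> float:
    """Exact flow map X_{s,u}(x) of the toy ODE."""
    return u * M + sigma(u) / sigma(s) * (x - s * M)


def biased_map(s: float, u: float, x: float) -> float:
    """A deliberately wrong 'learned' map: its average velocity is still right on the
    diagonal (the bias term vanishes as u -> s), but it is wrong for s != u."""
    return exact_map(s, u, x) + 0.3 * (u - s) ** 2


def semigroup_loss(flow_map, n: int = 2000, seed: int = 0) -> float:
    """Monte Carlo estimate of E|X_{w,u}(X_{s,w}(x)) - X_{s,u}(x)|^2, s <= w <= u, x ~ N(0, 1)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n):
        s, w, u = sorted(rng.random() for _ in range(3))
        x = rng.gauss(0.0, 1.0)
        two_jumps = flow_map(w, u, flow_map(s, w, x))
        one_jump = flow_map(s, u, x)
        total += (two_jumps - one_jump) ** 2
    return total / n


print(semigroup_loss(exact_map), semigroup_loss(biased_map))
```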
Importantly, with standard flow maps we transport the prior p_0 to a single distribution p_1. A slight generalization: suppose I'm interested not in one distribution p_1 but in a family of distributions, indexed by some context c. So I have a family p_c over a context set C, and I want to be able to sample from any p_c. An easy way to do this is to add an extra input to the flow map that takes the context or conditioning information c; once c is fixed, I train it as a flow map transporting p_0 to p_c, and I enforce this for all c. Now I have a flow map such that, given the right context information, I can generate one-shot samples from any of these distributions p_c: X_{0,1}(·; c)_# p_0 = p_c.
I think this is something we're all already very familiar with, for example from class-conditional generation. Often we have a dataset of images split into classes, say cats, dogs, hamburgers, and we want to sample from a specific subset of the distribution: just the images of cats, or just the images of dogs. So we take our usual flow matching model or flow map, let it take in the class index c, and train a usual flow map transporting the Gaussian p_0 to the data distribution of that class. We're basically putting a bunch of flow maps together, conditioned on some context information.
>> Can I ask a question? Sorry, I might have missed it. What does the hash symbol stand for?
>> Oh, sorry.
It's just the pushforward of a distribution. What I mean is: if I take a sample from p_0, the Gaussian, and plug it into the argument of my flow map, then the output is distributed according to p_1.
>> Okay, thank you.
>> Okay. So these are context-dependent flow maps. The focus of our talk is a very specific set of contexts and a specific set of distributions that we want to target; I'll explain later why we're interested in them. The contexts are time-state pairs: a time t in [0, 1] and a state vector x in R^d, so the context is the tuple c = (t, x). For that context, the distribution p_c is the law of I_1 conditioned on I_t = x, which I'll denote p_{1|t}(· | x). In words: I observe a noisy sample, I_t = x, and I want to sample from the distribution of clean data I_1 that is consistent with, i.e. conditioned on, that noisy sample. I'll explain in a few minutes why we're interested in these distributions, which is for reward alignment, but here's an overview. Imagine my data distribution p_1 is a distribution of images, and I observe this noisy image. We don't know exactly which clean data sample I_1 gave this noisy sample I_t, but conditioned on this value there is some distribution over the possible dogs that could have led to it. It could have come from this dog, or it could have come from this dog.
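For the hypothetical 1D Gaussian toy used in these sketches (prior N(0, 1), data N(M, S^2), linear interpolant), the posterior p_{1|t}(· | x) is Gaussian with a closed-form mean and standard deviation. Evaluating it at a fixed observed value for increasing t mirrors the noisy-image picture: at small t the observation is mostly noise and the posterior is wide, while near t = 1 it collapses onto x.

```python
import math

M, S = 2.0, 0.5  # hypothetical 1D toy: p0 = N(0, 1), p1 = N(M, S^2), I_t = (1 - t) I0 + t I1


def posterior(t: float, x: float) -> tuple[float, float]:
    """Mean and std of p_{1|t}(. | x) = Law(I1 | I_t = x), which is Gaussian here."""
    var_t = (1 - t) ** 2 + (t * S) ** 2          # Var(I_t)
    mean = M + t * S**2 / var_t * (x - t * M)    # E[I1 | I_t = x]; Cov(I1, I_t) = t * S^2
    std = S * (1 - t) / math.sqrt(var_t)         # sqrt(S^2 - (t * S^2)^2 / Var(I_t))
    return mean, std


# Same observed value x at increasing t (less noise): the posterior narrows around x.
for t in (0.1, 0.5, 0.9, 0.99):
    mean, std = posterior(t, 1.5)
    print(f"t={t:.2f}  mean={mean:.3f}  std={std:.3f}")
```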
So there is a whole distribution over clean images, which is exactly what I denote by p_{1|t}(· | x). What I want is that if I plug this specific t and x into my map, it induces a transport from the Gaussian to that specific posterior distribution. So if I plug this noisy dog into my map, I want to generate many different samples from the posterior distribution of the clean dog given this noisy observation. In particular, if I take different realizations of the initial noise ε, that is, different values for the first argument, but keep x fixed, I get different samples: this noise sample leads to this dog, and a different noise sample leads to a different dog. Both are plausible reconstructions of the noisy image. On the other hand, suppose I observe a different noisy sample, this noisy flower. It has a completely different distribution over the clean image: clean images of flowers. If I instead use this x_t as my context information, I target a completely different posterior distribution: the distribution of the clean flower that gave rise to this noisy flower. Notice that if I take the same initial noise ε, so the first argument is the same, then choosing the context (t, x) to be the noisy dog gives a sample of a dog, while choosing it to be the noisy flower gives a sample of a clean flower. So for this infinite collection of posterior distributions, indexed by all possible times t and noisy values x, we can define a transport that takes the Gaussian to any of them. That's what we call a meta flow map.
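In the same hypothetical 1D Gaussian toy, every posterior is Gaussian, so the one-step meta flow map X_{0,1}(ε; t, x) is simply an affine function of ε whose coefficients depend on the context. This sketch (the function names are illustrative, not from the paper) shows the two behaviors just described: the same ε under two contexts gives two different samples, and many draws of ε under one context reproduce that context's posterior.

```python
import math
import random

M, S = 2.0, 0.5  # hypothetical 1D toy: p0 = N(0, 1), p1 = N(M, S^2)


def posterior(t: float, x: float) -> tuple[float, float]:
    """Mean and std of p_{1|t}(. | x) for the Gaussian toy."""
    var_t = (1 - t) ** 2 + (t * S) ** 2
    return M + t * S**2 / var_t * (x - t * M), S * (1 - t) / math.sqrt(var_t)


def meta_flow_map_01(eps: float, t: float, x: float) -> float:
    """One-step meta flow map X_{0,1}(eps; t, x): sends eps ~ N(0, 1) to a posterior sample.
    Closed form only because every posterior is Gaussian in this toy."""
    mean, std = posterior(t, x)
    return mean + std * eps


rng = random.Random(0)
eps = rng.gauss(0.0, 1.0)
# Same noise eps, two different contexts (t, x): two different "clean" samples.
print(meta_flow_map_01(eps, 0.7, 0.5), meta_flow_map_01(eps, 0.7, 2.5))
# Different noise, same context: many samples from the same posterior.
draws = [meta_flow_map_01(rng.gauss(0.0, 1.0), 0.7, 0.5) for _ in range(20_000)]
print(sum(draws) / len(draws), posterior(0.7, 0.5)[0])
```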
As a quick overview of the rest of the talk, I'll spend the next ten minutes or so introducing reward alignment, which is the motivation for meta flow maps, and then I'll come back and describe meta flow maps in more detail and how we might train them. Here I just want to show one more image, so that it's easier to understand what these posterior distributions represent. The first column of each panel is a clean ImageNet image, and the second column is a noisy realization of that image; this panel has more noise than that one. The other images are samples from our trained meta flow map of the clean image corresponding to the noisy realization. When the image is very noisy, it's really not clear what the original image was, so the samples vary quite a bit: they're still somewhat related, but there's a lot of variation between them, for example different colors here. On the other hand...
>> And this is a conditional model, where you're also conditioning on a prompt, so the prompt already guides it a bit?
>> Yes. In this case we also condition on the class index. Over here, again, we have noisy realizations of the clean images, but they're much less noisy, so there's much more structure, and the posterior distribution is much more concentrated around the true image. Our samples still show some variation, but they're much more similar to each other and to the original image. These posterior distributions are exactly what we train our meta flow maps to sample from. Okay, so now coming back to reward alignment.
That's going to be our main goal. We have some pre-trained flow model, either the flow matching drift b_t or a flow map, that gives us samples from the data distribution p_1, and we want to align these samples to have high reward under some reward function r(x). There are many different reward functions you might consider: for example, they could arise as likelihoods, come from inverse problems, or be black-box neural networks that measure how much humans like an image or how well it adheres to a prompt. We can formalize the goal as sampling from the reward-tilted distribution p_reward: the original distribution p_1(x) tilted by the reward. We want to respect the original distribution but favor regions of high reward r(x).
There are two paradigms for alignment. One is inference-time steering: keep the base model fixed, don't modify its weights, and modify the sampling on the fly to target the new distribution p_reward. Fine-tuning, on the other hand, updates the model parameters to permanently target the reward-tilted distribution instead of the original distribution p_1. Both come with challenges. Inference-time steering is often very expensive, because you have to modify the dynamics of the pre-trained drift during sampling. Fine-tuning for every downstream reward is also very expensive; you don't want to do that every single time.
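The reward-tilted target can be illustrated with self-normalized importance sampling on a toy problem. This sketch assumes the tilt has the exponential form p_reward(x) ∝ p_1(x) exp(r(x)), a common convention that the transcript does not spell out, together with a hypothetical Gaussian p_1 and a linear reward, for which the tilted mean is known exactly (M + A S^2). Importance sampling is used here only as a reference check, not as one of the alignment methods discussed in the talk.

```python
import math
import random

# Hypothetical 1D toy: p1 = N(M, S^2), linear reward r(x) = A * x.
# Assumed tilt form: p_reward(x) proportional to p1(x) * exp(r(x)).
M, S, A = 2.0, 0.5, 1.2


def reward(x: float) -> float:
    return A * x


def tilted_mean_snis(n: int = 50_000, seed: int = 0) -> float:
    """Estimate E_{p_reward}[x] by self-normalized importance sampling with p1 as proposal."""
    rng = random.Random(seed)
    xs = [rng.gauss(M, S) for _ in range(n)]
    log_w = [reward(x) for x in xs]              # log importance weights, up to a constant
    shift = max(log_w)                           # subtract the max for numerical stability
    w = [math.exp(lw - shift) for lw in log_w]
    return sum(wi * xi for wi, xi in zip(w, xs)) / sum(w)


# For a Gaussian p1 and a linear reward, the tilt is again Gaussian: N(M + A * S^2, S^2).
print(tilted_mean_snis(), M + A * S**2)
```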
Our proposition is that if you train a meta flow map at training time alongside your original model, it's a bit more expensive, but it makes both inference-time steering and fine-tuning much cheaper. It could then be very feasible, for example, to fine-tune on any downstream reward. Despite the operational differences between the two settings, they can be unified mathematically with the same solution. What we seek is a new drift b*_t, where the star denotes the optimal drift, such that the ODE with this drift, still initialized at the Gaussian p_0, generates terminal samples from the reward-tilted distribution p_reward. The question is: what is the optimal vector field b*_t? It turns out that we can characterize it explicitly in terms of the value function, which measures the expected future reward of clean data from the current state. I define the value function V_t(x) as the log of the conditional expectation of the reward of clean data samples I_1, conditioned on the noisy observation I_t = x. I'm rewriting it here in the notation where X_1 is sampled from p_{1|t}(· | x), the same posterior distributions that we want to train our meta flow maps for. With this value function, the optimal drift b*_t(x) equals the original drift b_t(x) plus a correction term: a scalar coefficient times the gradient of the value function, ∇V_t(x). Where it comes from isn't too important, but for a bit of intuition, many of you are probably familiar with Doob's h-transform.
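The claim that b*_t = b_t + c_t ∇V_t can be checked on a Gaussian toy. Assumptions (ours, not the talk's): prior N(0, 1), data N(M, S^2), linear interpolant, linear reward r(x) = A x, and V_t(x) = log E[exp(r(I_1)) | I_t = x]. The talk leaves the coefficient unspecified; for the linear interpolant with a standard Gaussian prior, a short calculation via Tweedie's formula gives c_t = (1 - t)/t. The sketch compares the steered drift with the drift one would obtain by running flow matching directly on p_reward, which in this toy is N(M + A S^2, S^2).

```python
# Hypothetical 1D toy: p0 = N(0, 1), p1 = N(M, S^2), linear interpolant, reward r(x) = A * x.
# Assumed tilt: p_reward proportional to p1 * exp(r), which here is N(M + A * S^2, S^2).
M, S, A = 2.0, 0.5, 1.2


def drift(t: float, x: float, mean: float) -> float:
    """Flow-matching drift b_t(x) for the interpolant from N(0, 1) to N(mean, S^2)."""
    var_t = (1 - t) ** 2 + (t * S) ** 2
    return mean + (t * S**2 - (1 - t)) / var_t * (x - t * mean)


def grad_value(t: float, x: float) -> float:
    """d/dx of V_t(x) = log E[exp(A * I1) | I_t = x] = A * mu_post(x) + A^2 * var_post / 2.

    Only the posterior mean depends on x, with slope t * S^2 / Var(I_t).
    """
    var_t = (1 - t) ** 2 + (t * S) ** 2
    return A * t * S**2 / var_t


for t in (0.2, 0.5, 0.8):
    for x in (-1.0, 0.0, 2.0):
        steered = drift(t, x, M) + (1 - t) / t * grad_value(t, x)  # b_t + c_t * grad V_t
        direct = drift(t, x, M + A * S**2)                         # drift fitted to p_reward
        print(t, x, abs(steered - direct))                         # ~0 up to rounding
```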
Recall Doob's h-transform: given an SDE, I can define a value function for it, and the optimal control is given by the gradient of that value function. What we do here is convert the controlled SDE into its corresponding probability flow ODE, which gives us this ODE. If you choose the coefficient σ_t correctly, the value function of the SDE agrees with the value function defined for the interpolant. But it doesn't really matter too much; all that matters is that we have this explicit form: the sum of the original drift b_t(x) and a corrective term coming from the gradient of the value function, which intuitively points us toward regions where the expected terminal reward will be high. So, for inference-time steering, since we already have access to b_t(x), the question is how to estimate the gradient of the value function; for fine-tuning, we want to design an objective such that we can learn it.
So how can we estimate the gradient of the value function? Let me write down its definition again: this is the value function from before, and here is its gradient. To estimate this quantity, it's natural to want posterior samples from these distributions. So we want a generative map φ that efficiently generates samples from those posteriors, starting from some noise distribution q. (Later q will be p_0; I'm writing it a bit more generally here.) If we have such a map, which transforms noise from q into posterior samples, then we can reparameterize the value function as the log of the expectation, over that noise, of the reward evaluated at the output of our map.
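Written out, the reparameterization and the gradient it enables look as follows. This assumes the reward enters through an exponential, so that V_t(x) = log E[exp(r(I_1)) | I_t = x]; that form matches the log-sum-exp estimator described in the talk but is not spelled out in the transcript. The second line uses differentiability of φ in x to move the gradient inside the expectation.

```latex
\begin{aligned}
V_t(x) &= \log \mathbb{E}\big[e^{r(I_1)} \,\big|\, I_t = x\big]
        = \log \mathbb{E}_{\varepsilon \sim q}\big[e^{r(\phi(\varepsilon;\, t, x))}\big],
        \qquad \phi(\cdot\,;\, t, x)_{\#}\, q = p_{1|t}(\cdot \mid x),\\
\nabla_x V_t(x) &= \frac{\mathbb{E}_{\varepsilon \sim q}\big[e^{r(\phi)}\, (\partial_x \phi)^{\top} \nabla r(\phi)\big]}
                        {\mathbb{E}_{\varepsilon \sim q}\big[e^{r(\phi)}\big]},
        \qquad \phi = \phi(\varepsilon;\, t, x).
\end{aligned}
```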
Expressing the value function through the map φ is just a reparameterization. Now, we're actually interested in the gradient of the value function, so we need to differentiate this quantity with respect to x. To differentiate with respect to x, and in particular to exchange the expectation and the derivative, we also want the map φ to be differentiable with respect to x. That's the second condition: we have the conditional transport constraint, and we want φ to be differentiable in x. With both together, we get a consistent Monte Carlo estimator of the gradient of the value function. We draw N i.i.d. samples from the noise distribution q, push them forward through φ to get N posterior samples of clean data conditioned on x, estimate the value function with a log-sum-exp, and take the gradient of this expression. Since φ is differentiable with respect to x, we can differentiate through this quantity, and the result is a consistent Monte Carlo estimator of the exact steering term we wanted. So the question becomes: how can we train such maps φ? Once we have maps φ that satisfy the conditional transport constraint, giving us posterior samples, and that are differentiable in x, we can use this estimator of the gradient of the value function for steering. Our answer is meta flow maps; that's the approach we propose for training them. I should mention that other works have explored how to train similar objects.
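The estimator can be sketched end to end on a toy where the exact gradient is known. Assumptions (ours): Gaussian posteriors from the hypothetical 1D toy, so that φ(ε; t, x) = μ(x) + σ ε is affine and differentiable in x; a quadratic reward r(y) = -(y - C)^2/2; and the exponential form of V_t. Since Python's standard library has no autodiff, the derivative through φ is written out by hand with the chain rule. The estimator is consistent as N grows, though the log makes it biased at finite N.

```python
import math
import random

# Hypothetical 1D toy (not from the talk): p0 = N(0, 1), p1 = N(M, S^2), linear interpolant,
# reward r(y) = -(y - C)^2 / 2, and assumed value function V_t(x) = log E[exp(r(I1)) | I_t = x].
M, S, C = 2.0, 0.5, 3.0


def posterior_params(t: float, x: float) -> tuple[float, float, float]:
    """Posterior N(mu, sd^2) of I1 given I_t = x, plus d(mu)/dx (sd does not depend on x)."""
    var_t = (1 - t) ** 2 + (t * S) ** 2
    dmu_dx = t * S**2 / var_t
    mu = M + dmu_dx * (x - t * M)
    sd = S * (1 - t) / math.sqrt(var_t)
    return mu, sd, dmu_dx


def grad_value_mc(t: float, x: float, n: int = 20_000, seed: int = 0) -> float:
    """Gradient of the log-mean-exp estimate of V_t(x), with phi(eps; t, x) = mu(x) + sd * eps.

    d/dx log(mean_i exp(r(phi_i))) = sum_i softmax_i * r'(phi_i) * dphi_i/dx.
    """
    mu, sd, dmu_dx = posterior_params(t, x)
    rng = random.Random(seed)
    ys = [mu + sd * rng.gauss(0.0, 1.0) for _ in range(n)]  # posterior samples phi(eps_i)
    log_w = [-(y - C) ** 2 / 2 for y in ys]                # r(phi(eps_i))
    shift = max(log_w)
    w = [math.exp(lw - shift) for lw in log_w]             # softmax numerators, stabilized
    # r'(y) = -(y - C) = C - y; dphi/dx = dmu_dx because phi is affine in x in this toy.
    return sum(wi * (C - y) * dmu_dx for wi, y in zip(w, ys)) / sum(w)


def grad_value_exact(t: float, x: float) -> float:
    """Closed form: V_t(x) = -log(1 + sd^2) / 2 - (mu - C)^2 / (2 (1 + sd^2))."""
    mu, sd, dmu_dx = posterior_params(t, x)
    return -(mu - C) / (1 + sd**2) * dmu_dx


print(grad_value_mc(0.5, 1.0), grad_value_exact(0.5, 1.0))
```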
Many earlier approaches to training such maps used something like MMD, and those are often difficult to tune or scale. So, building on ideas from the flow map literature, which have proved very successful at scale, we decided to approach the problem via this route.
Let me give an overview of meta flow maps. For each context pair (t, x), we want to sample from the posterior distribution p_{1|t}(· | x). For each (t, x), we define a conditional auxiliary ODE that transports the prior p_0 to that specific posterior. So we have an infinite collection of posteriors to target, and for each one an ODE that achieves that transport: an infinite collection of ODEs. Of course, we could learn each ODE and integrate it, but that would be very expensive, and we wouldn't be able to differentiate through it, whereas we want the map to be differentiable with respect to x. Instead, note that each auxiliary ODE has a flow map, a solution operator that can jump from time zero to time one of the conditional trajectories, and let's train all of them in a single model. Why "meta" in the name? Because when we plug the context (t, x) into our model, we're effectively selecting one flow map from this infinite collection.
Let's go into more detail. First, recall the conditional probability flow ODEs we're interested in: ODEs that transport noise samples from p_0 to posterior samples. I use a bar everywhere to denote quantities that depend on the context (t, x), to distinguish them from the unconditional case we looked at before.
So we want a drift b̄_s, with the context (t, x) held fixed, such that integrating this ODE, again starting from the reference distribution p_0, ends at posterior samples. One way to define this drift is as the solution to a conditional flow matching problem. Before, we defined the drift as the solution to the flow matching problem from p_0 to the data distribution p_1; now we just replace p_1 with the posterior distribution. If I can sample from this posterior, I can do flow matching from p_0 to it, using these conditional samples in the same way as before, and I get exactly the drift that transports p_0 to this posterior. So we've now defined an infinite collection of auxiliary ODEs targeting these posterior distributions. Again, each conditional ODE has a solution operator: a map that can, in theory, jump from time zero straight to time one without infinitesimally integrating the ODE. A meta flow map is the collection of all these maps put together in one amortized network. More formally, the map X now takes the flow map times s and u as inputs; a first argument x̄, which is the input state of the conditional ODE; and the context information t and x. I want to be clear that t is completely unrelated to s and u. The times s and u are the inner flow times of the conditional ODEs we defined, whereas t tells us which target distribution, i.e. which target ODE, we're interested in. t has no relation to s and u; it doesn't have to be smaller or greater.
It's completely unrelated. Formally, if we fix t and x and take a trajectory x̄_s of the conditional ODE governed by the drift b̄, then, as a function of its first input, the meta flow map is the flow map of that ODE: it jumps from time s of the conditional ODE to time u, X_{s,u}(x̄_s; t, x) = x̄_u.
Let me come back to this image one more time. We have our noise distribution p_0 and different posterior distributions: this one, recall, is the posterior of the clean dog, and this one is the posterior of the clean flower. For each posterior there is some ODE that transports the Gaussian to it. For the dog posterior, all of these little red trajectories are ODE trajectories taking us from the Gaussian to that posterior; for the flower posterior, we have another set of ODE trajectories taking us from p_0 to it. Each of these ODEs can be defined by a conditional flow matching problem from p_0 to the respective posterior. What we're saying is: instead of simulating these infinitesimal dynamics and integrating all these conditional ODEs, let's compress all of them into efficient one- or few-step maps, meta flow maps, that jump directly from the start points of the conditional ODEs to their end points.
So how do we train meta flow maps? It's as simple as training standard consistency models and flow maps; we just plug in an extra context vector (t, x). That's it.
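Before turning to training, it may help to see the object being trained on the hypothetical 1D Gaussian toy. There, each posterior is Gaussian, so the conditional drift b̄_s(x̄; t, x) and the meta flow map X_{s,u}(x̄; t, x) both have closed forms. The sketch integrates the conditional ODE in the inner time s with many Euler steps and compares the endpoint with a single call of the meta flow map; the context time t is held fixed throughout and plays no role in the integration schedule.

```python
import math

M, S = 2.0, 0.5  # hypothetical 1D toy (not from the talk): p0 = N(0, 1), p1 = N(M, S^2)


def posterior(t: float, x: float) -> tuple[float, float]:
    """Mean and std of the Gaussian posterior p_{1|t}(. | x)."""
    var_t = (1 - t) ** 2 + (t * S) ** 2
    return M + t * S**2 / var_t * (x - t * M), S * (1 - t) / math.sqrt(var_t)


def aux_sigma(r: float, sd: float) -> float:
    """Std of the auxiliary interpolant (1 - r) Ibar0 + r I1 when I1 has posterior std sd."""
    return math.sqrt((1 - r) ** 2 + (r * sd) ** 2)


def cond_drift(s: float, xbar: float, t: float, x: float) -> float:
    """Conditional drift bbar_s(xbar; t, x): flow matching from N(0, 1) to p_{1|t}(. | x)."""
    mu, sd = posterior(t, x)
    return mu + (s * sd**2 - (1 - s)) / aux_sigma(s, sd) ** 2 * (xbar - s * mu)


def meta_flow_map(s: float, u: float, xbar: float, t: float, x: float) -> float:
    """X_{s,u}(xbar; t, x): exact solution operator of the conditional ODE above."""
    mu, sd = posterior(t, x)
    return u * mu + aux_sigma(u, sd) / aux_sigma(s, sd) * (xbar - s * mu)


def euler(xbar: float, t: float, x: float, n_steps: int = 2000) -> float:
    """Integrate the conditional ODE in the inner time s from 0 to 1 (t stays fixed)."""
    ds = 1.0 / n_steps
    for k in range(n_steps):
        xbar += ds * cond_drift(k * ds, xbar, t, x)
    return xbar


# One call of the meta flow map replaces 2000 drift evaluations.
print(euler(0.3, 0.6, 1.8), meta_flow_map(0.0, 1.0, 0.3, 0.6, 1.8))
```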
We parameterize the meta flow map in a residual form similar to before: X̂_{s,u}(x̄; t, x) = x̄ + (u - s) v̂_{s,u}(x̄; t, x), the identity x̄ plus a residual term, where v̂ represents the average velocity of the conditional trajectory. The hats denote that v̂ is parameterized by a neural network and induces the parameterized map X̂. The key realization is that if you fix t and x, you're in exactly the same setting as before: with (t, x) held fixed, this subnetwork should be a standard flow map transporting a Gaussian to a specific posterior distribution, in the same residual form as before. So if you sample contexts (t, x) and keep them fixed, you can apply any flow map or consistency loss to train it. Let me go over this in more detail. Remember that we have two losses, the diagonal loss and the consistency loss. The diagonal loss enforces that the instantaneous velocity v̂_{s,s} equals the true conditional drift b̄_s. The consistency loss propagates this infinitesimal correctness to s and u that differ, so that v̂_{s,u} actually represents the average velocity; together these define a valid flow map.
So how do we train these? We start with the diagonal loss, training from data, by sampling two coupled interpolants. First, I draw prior samples I_0 and Ī_0, which are independent. Then I draw a single data sample I_1. I form my query state as the interpolant: for a random time t, I interpolate between I_0 and I_1, I_t = (1 - t) I_0 + t I_1, the same as before.
But for my auxiliary path, I use an independent Ī_0 from the Gaussian and the same data point I_1 that I used for the query state: Ī_s = (1 - s) Ī_0 + s I_1. Let me say that again: the data sample I_1 is shared between the two interpolants, but the two noise vectors I_0 and Ī_0 are independent. If you do this, notice what happens when we condition on the value of I_t. Suppose the query state equals x. What is the distribution of the clean sample I_1? By definition, it's exactly the posterior distribution p_{1|t}(· | x) that we're interested in. And since Ī_0 was drawn independently of all these variables, conditioned on I_t = x, Ī_0 is a sample from p_0 and I_1 is a sample from exactly the posterior I'm interested in. So, conditioned on the value of I_t, the auxiliary path traces a path from the prior p_0 to that specific posterior distribution. That is exactly the flow matching setting: I have my prior sample and my data sample, and I can do flow matching to get the instantaneous velocity. Concretely, I plug in t and I_t as the context, plug in the auxiliary interpolant Ī_s as the first argument, and regress onto the time derivative of the auxiliary interpolant, which is I_1 - Ī_0. When I minimize this, the network output equals the conditional expectation of the right-hand side, which is exactly the conditional drift we're interested in. To summarize this slide: we now have a loss that enforces the correct conditional drifts on the diagonal.
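The coupled-interpolant construction can be checked on the same hypothetical Gaussian toy. The sketch below (names and toy setup are illustrative assumptions) draws I_0, Ī_0, I_1, t, and s exactly as described and evaluates a Monte Carlo estimate of the diagonal loss for two candidate drifts: the true conditional drift b̄_s, and the unconditional drift b_s, which ignores the context. Because the loss is an L2 regression onto I_1 - Ī_0, the conditional expectation should score lower, although neither reaches zero.

```python
import math
import random

M, S = 2.0, 0.5  # hypothetical 1D toy (not from the talk): p0 = N(0, 1), p1 = N(M, S^2)


def posterior(t, x):
    var_t = (1 - t) ** 2 + (t * S) ** 2
    return M + t * S**2 / var_t * (x - t * M), S * (1 - t) / math.sqrt(var_t)


def cond_drift(s, xbar, t, x):
    """True conditional drift bbar_s(xbar; t, x) = E[I1 - Ibar0 | Ibar_s = xbar, I_t = x]."""
    mu, sd = posterior(t, x)
    var_s = (1 - s) ** 2 + (s * sd) ** 2
    return mu + (s * sd**2 - (1 - s)) / var_s * (xbar - s * mu)


def uncond_drift(s, xbar, t, x):
    """A wrong candidate that ignores the context (t, x): the unconditional drift b_s(xbar)."""
    var_s = (1 - s) ** 2 + (s * S) ** 2
    return M + (s * S**2 - (1 - s)) / var_s * (xbar - s * M)


def diagonal_loss(v_ss, n=50_000, seed=0):
    """Monte Carlo estimate of E|v_ss(s, Ibar_s; t, I_t) - (I1 - Ibar0)|^2 (coupled interpolants)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n):
        i0, i0bar = rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)  # two independent noises
        i1 = rng.gauss(M, S)                                  # one shared data sample
        t, s = rng.random(), rng.random()                     # context time t, inner time s
        x = (1 - t) * i0 + t * i1                             # query state I_t (context)
        xbar = (1 - s) * i0bar + s * i1                       # auxiliary interpolant Ibar_s
        total += (v_ss(s, xbar, t, x) - (i1 - i0bar)) ** 2
    return total / n


# Same seed for both, so the comparison uses identical samples.
print(diagonal_loss(cond_drift), diagonal_loss(uncond_drift))
```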
Now we need to propagate this correctness to pairs with s ≠ u, so that we actually have a valid flow map, that is, so that v_{s,u} is the average velocity along these auxiliary trajectories. The diagonal loss enforces the correct instantaneous drift; to make sure we really have a valid flow map, we also need consistency, and we can do the same thing as before. Fix t and x and view the map as a function of its first three inputs. Then we're in the standard flow map regime, and we just want it to satisfy the semigroup condition. Once again, you can apply any consistency loss you want; in our work we explored several, for example MeanFlow, Lagrangian, Eulerian, and shortcut losses. For concreteness, I'll write down the same loss as before. You fix t and x and enforce the semigroup property on the first three inputs: jumping from time s to w and then from w to u should be the same as jumping directly from s to u. Okay? Minimizing these two losses together enforces that we have a valid meta flow map: for every context (t, x), we have a true flow map targeting the specific posterior distribution we're interested in. So that was training from data. You can also train these maps by distillation. If the prior P_0 is Gaussian, which we'll often assume, and you already have a pre-trained unconditional model, by which I mean just the drift b_t without any of the context information, you can distill directly from it.
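A semigroup-style consistency loss for a fixed context can be sketched in the same Gaussian toy, where the exact meta flow map is available in closed form. Assumptions (not from the talk): 1D Gaussian data, a standard Gaussian prior, the linear interpolant, and hypothetical helper names; the closed form relies on the fact that in 1D the ODE flow between Gaussian marginals is the monotone affine map between them.

```python
import math
import random

# Toy assumptions (not from the talk): 1D data N(M, V), prior N(0, 1), linear
# interpolant. In 1D the ODE flow between Gaussian marginals is the monotone
# affine map, which gives the exact meta flow map below. Names are hypothetical.
M, V = 2.0, 0.5


def posterior(t, x):
    """Mean and variance of I1 given I_t = x."""
    prec = 1.0 / V + t**2 / (1 - t) ** 2
    return (M / V + t * x / (1 - t) ** 2) / prec, 1.0 / prec


def exact_meta_flow_map(s, u, xbar, t, x):
    """X_{s,u}(xbar; t, x): auxiliary flow from N(0, 1) toward p(I1 | I_t = x)."""
    pm, pv = posterior(t, x)

    def sd(r):  # std of the auxiliary marginal at time r
        return math.sqrt((1 - r) ** 2 + r**2 * pv)

    return u * pm + sd(u) / sd(s) * (xbar - s * pm)


def consistency_loss(flow_map, n=2_000, seed=0):
    """Mean squared violation of X_{w,u}(X_{s,w}(xbar)) = X_{s,u}(xbar), s <= w <= u."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n):
        t, x = rng.uniform(0.05, 0.95), rng.gauss(0.0, 1.5)  # fixed context
        s, w, u = sorted(rng.uniform(0.0, 1.0) for _ in range(3))
        xbar = rng.gauss(0.0, 1.0)
        two_jumps = flow_map(w, u, flow_map(s, w, xbar, t, x), t, x)
        one_jump = flow_map(s, u, xbar, t, x)
        total += (two_jumps - one_jump) ** 2
    return total / n


def perturbed(s, u, xbar, t, x):
    """Adds an error whose velocity vanishes on the diagonal but breaks the semigroup."""
    return exact_meta_flow_map(s, u, xbar, t, x) + 0.1 * (u - s) ** 2


print(consistency_loss(exact_meta_flow_map), consistency_loss(perturbed))
```

The perturbation changes v_{s,u} by 0.1·(u − s), which is zero when s = u, so a diagonal loss alone cannot detect it; only the consistency term does.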
This uses a result from GLASS Flows, which is a really cool paper by Peter Holderrieth, a different Peter. What he showed in that paper is that the ground-truth conditional drift b̄_s, the drift that depends on the context (t, x), can be expressed analytically using only our original drift b_t. It's a linear combination of x̄, x, and a reparameterized evaluation of b. The exact form doesn't matter; what matters is that the conditional drift b̄_s we're trying to learn is available to us analytically from an unconditional model b_t. Since this is exactly the target we want for our diagonal term, we don't have to go through the flow matching route if we don't want to. We can just regress directly onto this target. So we ensure that when I plug in the context (t, x), the instantaneous drift v_{s,s} equals b̄_s at that point x̄. And I can evaluate this target explicitly because the GLASS Flows result tells me I can just reparameterize my unconditional model b_t. The reason you'd want to do this is that you now have zero-variance targets: the minimum value of this loss is exactly zero, attained when v_{s,s} exactly equals the conditional drift. Unlike flow matching, where the loss retains irreducible variance even at the minimizer, here there is none. And you can sample x̄ and x from any distribution you want. So that's how you get the diagonal loss, and then you again apply a consistency loss to train the off-diagonal terms. If you have a teacher, there are also some tricks to make this more scalable, but I'll leave those for later. Okay. So now we have meta flow maps, and I've gone quickly over how to train them.
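The talk skips the exact GLASS Flows expression, so the following is not that formula. It is a self-contained derivation of the same kind of identity under stated assumptions (linear interpolant, standard Gaussian prior, independent noises): x and x̄ are two Gaussian observations of the same I_1, so their combined likelihood equals a single observation at an effective time τ and state z, and the conditional drift can be written using only the unconditional drift b_τ. The 1D Gaussian "teacher" and all function names are hypothetical, chosen so the result can be checked against the closed-form posterior.

```python
import math

# Toy assumptions (not from the talk): 1D data N(M, V) with a known "teacher"
# drift, prior N(0, 1), linear interpolant I_t = (1 - t) I0 + t I1. This is our
# own derivation of a GLASS-Flows-style identity, not the paper's formula.
M, V = 2.0, 0.5


def teacher_drift(tau, z):
    """Unconditional drift b_tau(z) = E[I1 - I0 | I_tau = z] (analytic in the toy)."""
    denoised = M + tau * V * (z - tau * M) / (tau**2 * V + (1 - tau) ** 2)
    return (denoised - z) / (1 - tau)


def conditional_drift_from_teacher(s, xbar, t, x):
    """Distillation target bbar_s(xbar; t, x) using only teacher_drift.

    x and xbar are two Gaussian observations of the same I1; their product
    likelihood equals one observation z at an effective time tau.
    """
    lam = t**2 / (1 - t) ** 2 + s**2 / (1 - s) ** 2       # combined precision
    y = (t * x / (1 - t) ** 2 + s * xbar / (1 - s) ** 2) / lam
    tau = 1.0 / (1.0 + 1.0 / math.sqrt(lam))               # (1 - tau)/tau = 1/sqrt(lam)
    z = tau * y                                             # equivalent interpolant state
    denoised = z + (1 - tau) * teacher_drift(tau, z)       # E[I1 | x, xbar]
    return (denoised - xbar) / (1 - s)


def closed_form_drift(s, xbar, t, x):
    """Reference value from the full two-observation Gaussian posterior."""
    prec = 1.0 / V + t**2 / (1 - t) ** 2 + s**2 / (1 - s) ** 2
    mean = (M / V + t * x / (1 - t) ** 2 + s * xbar / (1 - s) ** 2) / prec
    return (mean - xbar) / (1 - s)


print(conditional_drift_from_teacher(0.3, 0.1, 0.6, 1.2), closed_form_drift(0.3, 0.1, 0.6, 1.2))
```

The target is a linear combination of x̄, x, and one teacher evaluation at a reparameterized time and point, which matches the shape of the identity described in the talk, and it has no sampling noise, so the regression loss can reach zero.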
Now let's discuss how we can do inference-time steering with meta flow maps. Just to recall our overall goal: we want to sample from the reward-tilted distribution, P_1(x) tilted by the exponential of the reward. The ODE we want to simulate for this has as its drift our unconditional drift b_t(x) plus a correction term coming from the gradient of the value function. We either assume we already have the unconditional drift or, if we don't, we can extract it from a meta flow map. I won't go through the details, but if you train a meta flow map, you implicitly already have b_t(x), and we used that in all of our experiments. So what we really need is an estimate of the steering guidance term, and for that we use the estimator I wrote down a few slides ago. There I had φ as a generic object; now I instantiate it with our meta flow map. Concretely, at each step, to estimate the gradient of the value function at a specific location x, we generate N noise samples ε_i, push them forward in one step to N posterior samples conditioned on the context, estimate the value function using a log-sum-exp, and then take the gradient with respect to x through this whole procedure. Because the meta flow map is differentiable in x, I can differentiate straight through it. And importantly, because meta flow maps are defined as one-step maps rather than as an iterative integration, backpropagating through them is not that expensive. That gives us an estimate of the gradient of the value function, which we plug into the Doob ODE from before; we take a step and then repeat the process.
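The steering step can be sketched in the same 1D Gaussian toy, where the exact meta flow map is X_{0,1}(ε; t, x) = m(t, x) + sqrt(v(t))·ε, so the chain rule through it can be written by hand instead of with autodiff. Assumptions, not from the talk: the Gaussian toy, the hypothetical helper names and reward, and the guidance factor (1 − t)/t in the Doob step, which is what one gets for the linear interpolant with a standard Gaussian prior; the paper's parameterization may differ.

```python
import math
import random

# Toy assumptions (not from the talk): 1D data N(M, V), prior N(0, 1), linear
# interpolant. The exact one-step posterior sampler stands in for a trained
# meta flow map; all names are hypothetical.
M, V = 2.0, 0.5


def posterior_params(t, x):
    """Posterior mean and variance of I1 given I_t = x, and d(mean)/dx."""
    prec = 1.0 / V + t**2 / (1 - t) ** 2
    mean = (M / V + t * x / (1 - t) ** 2) / prec
    return mean, 1.0 / prec, t / (1 - t) ** 2 / prec


def value_and_grad(t, x, reward, dreward, eps):
    """Estimate V_t(x) = log E[exp r(X_1) | X_t = x] and dV/dx from N draws eps.

    X_1 = mean(t, x) + sqrt(var(t)) * eps is the toy one-step map; the gradient
    is the chain rule through it (autodiff would do this for a real network).
    """
    mean, var, dmean_dx = posterior_params(t, x)
    x1 = [mean + math.sqrt(var) * e for e in eps]   # N one-step posterior samples
    r = [reward(v) for v in x1]
    r_max = max(r)                                   # stabilized log-sum-exp
    weights = [math.exp(ri - r_max) for ri in r]
    total = sum(weights)
    value = r_max + math.log(total / len(eps))
    # d/dx log(mean_i exp r(x1_i)) = sum_i w_i r'(x1_i) dx1_i/dx, w_i self-normalized
    grad = sum(w * dreward(v) for w, v in zip(weights, x1)) / total * dmean_dx
    return value, grad


def doob_euler_step(t, x, dt, grad_v):
    """Euler step of dx/dt = b_t(x) + (1 - t) / t * dV/dx (assumed guidance factor)."""
    mean, _, _ = posterior_params(t, x)
    b = (mean - x) / (1 - t)                         # unconditional drift E[I1 - I0 | x]
    return x + dt * (b + (1 - t) / t * grad_v)


def reward(v):  # hypothetical smooth reward preferring samples near 3
    return -0.5 * (v - 3.0) ** 2


def dreward(v):
    return -(v - 3.0)


rng = random.Random(0)
t, dt = 0.05, 0.05
x = rng.gauss(t * M, math.sqrt((1 - t) ** 2 + t**2 * V))  # draw from p_t
while t < 0.95:
    eps = [rng.gauss(0.0, 1.0) for _ in range(32)]        # fresh noise each step
    _, grad_v = value_and_grad(t, x, reward, dreward, eps)
    x = doob_euler_step(t, x, dt, grad_v)
    t += dt
print(x)
```

For a linear reward r(x_1) = a·x_1 the estimated gradient equals the exact ∇V_t = a·∂m/∂x for any number of samples, which makes a useful sanity check. For nonlinear rewards the self-normalized weights make the estimate biased for finite N, with the bias shrinking as N grows.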
Just to quickly mention: we primarily explored steering with this gradient estimator for the Doob ODE, but if you have a non-differentiable reward, we also have estimators of the gradient of the value function that don't rely on reward gradients, so the approach still applies. There are also many other ways to use the value function for steering; people have looked at SMC techniques and Monte Carlo tree search, and we explored some of those in our work as well.
>> How do you handle non-differentiable rewards in this setting?
>> We just use a different estimator that implicitly uses Stein's identity. It's similar to tilt matching.
>> I see. Got it.
>> Okay. Maybe I'll come back to fine-tuning if we have time. In case I don't: you can also do off-policy, unbiased fine-tuning. Rather than regressing against the estimators above, which have a small but non-zero bias, we found a way to make the objective unbiased. If you're interested, I can discuss that later. Moving on to the experiments: we also evaluated on lower-dimensional problems, like Gaussian mixture models, where we can verify quantitatively, using several metrics, that we're learning the right things. But I'll focus here on our ImageNet experiments. We trained on ImageNet at 256×256 resolution, taking a standard SiT latent diffusion model and adding a few adaptations for the extra inputs. These models usually take one time input and one x input, whereas our map needs three time inputs and two x inputs; coming back to this slide, you can see those extra inputs.
So we had to add a few extra parameters to accommodate all these inputs, but we only added about 8 million parameters relative to a total of 675 million, roughly a 1.2% increase, so it's quite lightweight. We trained both from data and by distillation. If you're interested in the results from data, you can look at the paper; our best results came from taking a pre-trained flow map and fine-tuning it into a meta flow map. Our pitch is that training the base flow map takes far more work, and adapting it into a meta flow map is lightweight by comparison. To give you a rough estimate: training the original SiT model, the flow matching part, took around 800 epochs; training the flow map took around 80 epochs, and that's where we took our checkpoint from; and training the meta flow map took an additional 30 epochs. I will say that one step of meta flow map training is a bit more expensive than one step of flow map training, which in turn is a bit more expensive than one step of base flow matching training. So it's not an exact comparison, but it is far less compute than the rest of training. We took this checkpoint, and you could use it for unconditional sampling. There isn't really a reason you'd want to do that, since you should probably just use your base flow map instead, but we wanted to show that even though the model can now sample from an infinite family of posterior distributions, not just one distribution, it still achieves a respectable FID of 1.97 in two steps. Now let's discuss posterior sampling itself.
The first group of evaluations was meant to check whether our meta flow map is actually targeting these posterior distributions on ImageNet. We did this in two ways, and both experiments are taken from GLASS Flows. In the first experiment, we took a set of real ImageNet images x_1 and noised them to some level t to get noisy samples x_t. We then drew posterior samples of clean data conditioned on these noisy samples, using either our meta flow map or the baseline. That gives two groups of samples, the original data samples and our posterior samples, and we computed the FID between them. The second metric was value function estimation. What we really care about with these posterior samples is computing expectations under them, so we estimated the value function using the posterior samples and compared it to a high-fidelity, expensive SDE rollout that we treat as ground truth, to see how well we recover the value function. We compared meta flow maps to GLASS Flows rollouts using the teacher model. GLASS Flows uses the teacher model directly, so it isn't affected by any imperfections in our training, and we are trying to recover the same conditional ODE trajectories as GLASS Flows, just in one step; that's why we chose it as the comparison. Here's a quick overview of the results. The dotted curves are GLASS Flows at different conditioning times, and the solid curves are meta flow maps. What you can see is that, even at a very small number of floating-point operations, meta flow maps already do a very good job of preserving the posteriors as measured by FID.
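The logic of the first posterior check can be illustrated in 1D. If posterior samples are exact, then noising real data and resampling clean data from p(x_1 | x_t) reproduces the data distribution, so the distance between the two groups should be near zero, while a sampler that collapses to the posterior mean loses variance and is penalized. This sketch assumes 1D Gaussian data, uses a closed-form 1D Gaussian Fréchet distance in place of FID (which compares Inception features), and uses hypothetical helper names.

```python
import math
import random
import statistics

# Toy assumptions (not from the talk): 1D data N(M, V), prior N(0, 1), linear
# interpolant. Names are hypothetical.
M, V = 2.0, 1.0


def frechet_1d(a, b):
    """Fréchet distance between Gaussian fits of two 1D samples (1D analogue of FID)."""
    mean_gap = statistics.fmean(a) - statistics.fmean(b)
    std_gap = statistics.pstdev(a) - statistics.pstdev(b)
    return mean_gap**2 + std_gap**2


def posterior_sample(t, x, rng, stochastic=True):
    """Exact one-step draw of I1 given I_t = x (stand-in for a meta flow map)."""
    prec = 1.0 / V + t**2 / (1 - t) ** 2
    mean = (M / V + t * x / (1 - t) ** 2) / prec
    return mean + (math.sqrt(1.0 / prec) * rng.gauss(0.0, 1.0) if stochastic else 0.0)


def posterior_fid(t, n=20_000, seed=0, stochastic=True):
    """Noise real samples to level t, resample clean data, compare the two sets."""
    rng = random.Random(seed)
    data = [rng.gauss(M, math.sqrt(V)) for _ in range(n)]
    noisy = [(1 - t) * rng.gauss(0.0, 1.0) + t * x1 for x1 in data]
    resampled = [posterior_sample(t, xt, rng, stochastic) for xt in noisy]
    return frechet_1d(data, resampled)


print(posterior_fid(0.5), posterior_fid(0.5, stochastic=False))
```

The deterministic variant returns only posterior means, whose spread is smaller than the data's, so the metric flags it even though each individual output is a plausible point estimate.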
Once you get to two steps, it basically plateaus, at roughly what GLASS Flows achieves with about 30 steps. So this indicates that we are preserving those posterior distributions. The same holds for the value function: at a small number of floating-point operations, we do much better than rolling out ODEs with just one or two steps, and this is where we see the most benefit. The reason we care about this low-compute regime is that it's exactly what we need for gradient estimation. We don't want to backpropagate through long rollouts of, say, 10 steps; what matters is the one- or two-step regime. So that established that our meta flow maps were learning these posterior distributions, and next we wanted to put them into practice for inference-time steering. For this example, we ran a class-conditional ODE targeting the class "tabby cat," so we're already targeting a specific class. We then steered using three separate rewards, in three separate experiments: ImageReward, PickScore, and HPS. For all of them, we used the same prompt: "a high-quality, high-resolution photograph of a tabby cat." For those less familiar with these reward functions, each takes a clean image and a prompt and estimates whether humans would find the image appealing and whether it aligns with the prompt. Steering with these rewards therefore biases samples to look better and adhere to the prompt. Here are the results. The three upper curves with dots show what I've been presenting so far, the gradient estimator of the value function.
If you look at these other curves, those are best-of-N: you roll out many trajectories and keep the best final sample. In all of these cases, our gradient estimator beats best-of-1,000, and does so with over 100 times less compute. I also want to point out the scaling behavior. The x-axis is the number of floating-point operations, and for the gradient estimator this is set by the number of Monte Carlo samples: the first point uses one Monte Carlo sample, the next uses two, and so on. As we add more Monte Carlo samples at each step, performance keeps improving. As I mentioned briefly, we also tried other steering approaches. For example, our MFM-search variant, which is based on a search algorithm that doesn't use reward gradients, also outperforms best-of-N. And here are some images to show that we're steering sensibly rather than purely reward hacking. The first column shows base samples from our MFM, all tabby cats. The second column shows what we get when we steer using one Monte Carlo sample per step to estimate the gradient of the value function, and on the right we use 32 samples. You can see a nice progression as the number of Monte Carlo samples increases. I probably don't have time to go into the details of fine-tuning, but let me quickly show some experimental results.
For fine-tuning, we trained over all classes rather than just tabby cat, using a generic prompt, "a high-quality, high-resolution photograph of a [class]," filled in with the class of each real training image. We trained only on HPS, but we plot performance over training iterations on all three rewards, so that we can see how, for example, PickScore and ImageReward change while we train only on HPS. As expected, HPS scores increase. But it was also nice to see that PickScore and ImageReward increase too, suggesting that this is not pure reward hacking. To visualize some samples: the first column again shows base samples from our meta flow map, and moving from left to right we increase the tilt strength, that is, the weight on the reward function. These are three separate, independent training runs, each with a higher reward weight. You can see clearly that as we increase the weight, the reward has more and more influence. Eventually the images become oversaturated, but in the middle regime you get nice images without reward hacking. Okay, I think that's all I wanted to say. Are there any questions?
>> Let's open the floor, and see if anyone online has a question.
>> I have a quick one. This method seems closer to deterministic. Did you look at the diversity of the samples you generate, and how steering might affect it?
>> We didn't really look into that much. You could do exactly the same steering with SDEs if you wanted to.
We presented ODEs just for simplicity, but we also have results for SDEs, and they were pretty similar. In theory they should give the same terminal distribution, so there shouldn't be much difference.
>> Maybe a simple background question. ODEs, even with simple velocity fields, can sometimes have very complex solutions, like the Lorenz equations. In some sense that's because the solution is obtained by iteratively composing the velocity field with itself, so the expense of solving the ODE sometimes reflects real complexity. Is there a reason a priori to expect that the flow map can be represented by a neural network of similar size to the one representing the underlying ODE?
>> I'm definitely not an expert in that. I believe there's quite a bit of work showing that, in general, the evolution from a Gaussian to these data distributions has nice properties, where the underlying ODE is quite well behaved and amenable to flow map training. And in practice it seems to work pretty well. But I agree that for a generic ODE it could be really hard to expect a flow map to learn anything; in this case, it's quite nice.
>> Okay, thank you.
>> Maybe I'll ask a couple of questions. In cases where the exponential of the reward has high variance under the base process, you would expect to need a very large number of samples.
>> Yes.
>> And I'm assuming that in these experiments the rewards are pretty dense, so you have a pretty good signal. I guess that also helps make things work.
>> Yes. Roughly speaking, we're taking samples from this posterior distribution.
And if that is very high variance, you'll need a lot of samples to estimate it well. So the method does rely on not needing too many samples. But I'd still say that, even without a meta flow map, high variance just means you'd have to unroll a huge number of trajectories instead, which is still very expensive.
>> In the plots you showed, what's the range of the number of samples you use to estimate these expectations?
>> We go from 1 to 32.
>> Okay. And is it similar for fine-tuning?
>> For fine-tuning, we actually don't really... I'll explain very quickly. This is the optimal drift we're trying to target.
>> Mhm.
>> If you multiply through by the denominator, you get a new optimality condition without the self-normalization issue, where you have a ratio of two estimators; you can bring everything under one expectation. So there isn't really a notion of a number of Monte Carlo samples, because you're just sampling over ε, t, and x.
>> So you just sample one at a time?
>> Yes, you're basically sampling one at a time.
>> One other point: in general these rewards, like ImageReward or HPSv2, are very non-smooth. If you evaluate the reward at one image and then at a nearby image, you can get very different values. I was wondering how much of the improvement comes from a better estimate of the gradient of the value function, versus the smoothing effect you also get from averaging gradients.
>> That's a great question. For example, here we are taking only one sample.
>> Sure.
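The multiply-through step can be written out as follows. This is a hedged reading of the slide, not the paper's exact objective: it assumes the standard reward-tilted drift, in which the optimal drift is the reward-weighted conditional expectation of the interpolant velocity İ_t, and the linear interpolant, for which İ_t = (I_1 − x)/(1 − t) given I_t = x. Here X_{0,1}(ε; t, x) denotes a one-step meta flow map sample.

```latex
b^\star_t(x)
= \frac{\mathbb{E}\big[e^{r(I_1)}\,\dot I_t \,\big|\, I_t = x\big]}
       {\mathbb{E}\big[e^{r(I_1)} \,\big|\, I_t = x\big]}
\;\Longleftrightarrow\;
\mathbb{E}\big[e^{r(I_1)}\,\big(b^\star_t(x) - \dot I_t\big) \,\big|\, I_t = x\big] = 0

b^\star = \arg\min_{\beta}\;
\mathbb{E}_{t,\,x,\,\varepsilon}\Big[
  e^{r(X_{0,1}(\varepsilon;\, t, x))}
  \,\Big\| \beta_t(x) - \frac{X_{0,1}(\varepsilon;\, t, x) - x}{1 - t} \Big\|^2
\Big]
```

Each term uses a single draw ε per (t, x), so no ratio of Monte Carlo estimates appears and stochastic gradients of the weighted loss are unbiased; x can be drawn from any full-support distribution, which is what makes the objective off-policy.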
So there wouldn't be any smoothing effect, and with one sample we're already beating best-of-1,000. So I think smoothing definitely helps a bit, but here it's more about taking posterior samples. I also think the answer probably depends on the reward. It may not be obvious here, but we observed quite different behavior for ImageReward compared to the other two.
>> Yes, I know that ImageReward in particular is very non-smooth.
>> Non-smooth, okay. So with ImageReward we definitely observed some interesting behavior compared to the other two, and there's something to look into there.
>> Yeah, but it's great. I mean, the two effects combine. Those are my questions.
>> May I ask a question?
>> Yep.
>> First of all, thanks a lot for your talk; beautiful work. Here you use meta flow maps for fine-tuning, right? And you have this unbiased objective where you learn a vector field.
>> Yes, this one.
>> I'm interested in the following: if you have your prior meta flow map and then you learn the vector field, can you learn not just the vector field but a full meta flow map, for this b*?
>> That's a great point. With the fine-tuning objective, the way we presented it is how to learn the instantaneous velocity b_t(x), but you could easily learn the entire flow map at the same time, for example a standard flow map with self-distillation, or a meta flow map. So you can train whichever you want.
The main technical help from the meta flow map is for training the instantaneous drift b_t(x); the rest is self-distillation techniques on top of it. But you could definitely train any of them at the same time.
>> Okay, so you take this b* and use some other flow map training technique on top of it, right?
>> Yeah.
>> Okay, thanks.
>> I have a question following up on the last one. What if we directly train the meta flow map to represent the posterior distribution of the fine-tuned path, that is, p(x_1 | x_t) where x_1 follows the fine-tuned distribution?
>> Yes, you could definitely do that. We didn't explore that specific case because our original pitch was that many people want to fine-tune their existing model, and for each downstream reward you might not need that model for long; you just want a few samples under a new reward. So we were focusing on really quick adaptation. But you could definitely fine-tune it on a tilted reward, and that's something we're actually trying out now.
>> Cool, thanks. That was very interesting work and a very clear explanation.
>> Thank you. Thanks so much.

Original Description

Controlling generative models—whether via inference-time steering or fine-tuning—is expensive. Control relies on estimating the value function—typically necessitating costly trajectory simulations. To eliminate this bottleneck, we introduce Meta Flow Maps (MFMs), stochastic extensions of consistency models and flow maps. MFMs are trained to perform one-step posterior sampling, generating arbitrarily many i.i.d. draws of clean data x_1 from any noisy state x_t. Crucially, these samples are differentiable in the conditioning state x_t, unlocking efficient estimation of the value function gradient. We leverage this capability to enable both inference-time steering without inner rollouts, and unbiased, off-policy fine-tuning to general rewards. Among our fine-tuning and steering experiments on ImageNet, we highlight that our single-particle steered-MFM sampler outperforms a Best-of-1000 baseline across multiple rewards at a fraction of the compute. Speaker Bio: Peter Potaptchik is a PhD student at Oxford advised by Yee Whye Teh, and a visiting fellow at Harvard advised by Michael S. Albergo. Find seminar details and upcoming talks: https://www.microsoft.com/en-us/research/event/microsoft-research-new-england-generative-modeling-sampling-seminar/