B-T simplification of Christiano preference learning
a guest, May 16th, 2019
- While playing with GPT-2 for various things (mostly poetry: https://www.gwern.net/GPT-2 ), I've been thinking about the potential for preference learning, and I think the original architecture can be simplified & improved.
- The motivation for the double-critic architecture is that the data being collected from humans is pairwise, and so one trains the critic to predict comparisons. This outer training loop then contains an inner G/agent training loop etc. The double training loop is necessary to collect ratings from brand-new areas of state space that the G/agent can newly access, but also, I get the impression, to keep the D/critic (GAN-style) from becoming too powerful and saturating its loss?
- But, just because the input is pairwise doesn't mean that the output must also be pairwise. It could instead be a scalar, with the D/critic performing regression.
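A sketch of the two critic objectives, assuming (as in Christiano et al. 2017, and omitting the third "tie" class) that the pairwise critic scores each input separately and turns the score difference into a comparison probability with a logistic link. Here $D$ is the critic, $\hat{z}_x$ is the B-T latent estimated for datapoint $x$, and $\sigma$ is the logistic function:

```latex
\begin{aligned}
\text{pairwise:}\quad & \mathcal{L}_{\text{pair}} = -\sum_{(a \succ b)} \log \sigma\big(D(a) - D(b)\big),
  \qquad \sigma(t) = \frac{1}{1 + e^{-t}} \\
\text{regression:}\quad & \mathcal{L}_{\text{reg}} = \sum_{x} \big(D(x) - \hat{z}_x\big)^2
\end{aligned}
```

The pairwise objective is itself a Bradley-Terry likelihood with the critic's outputs playing the role of the latent strengths. That is why a pairwise critic implicitly learns a quality rating. The regression version fits B-T once, outside the network, and hands the critic the resulting $\hat{z}_x$ directly.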
- A Bradley-Terry model is extremely simple and easy to estimate even on large samples, and can easily produce cardinal rankings (eg my simple interactive R tool for ranking, https://www.gwern.net/Resorter , just re-estimates the entire B-T model on each interaction because it takes like a tenth of a second at most). Each datapoint gets an estimated cardinal ranking in standard deviations of a hypothetical latent Gaussian. The D/critic is then trained to do regression from a single input to the estimated latent variable of quality.
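This is a minimal Bradley-Terry fit in plain Python, offered as a sketch rather than a reference implementation; a real one would use an established package such as R's BradleyTerry2. It maximizes the B-T log-likelihood by gradient ascent. It assumes a small ridge (Gaussian-prior) penalty so that estimates stay finite even for datapoints that never lost or never won. It then rescales the logit-scale strengths to mean 0 and SD 1. That rescaling is one simple way to get the N(0,1) convention described above (an assumption, not necessarily what Resorter does). The names `fit_bradley_terry` and `standardize` are hypothetical.

```python
import math
from typing import List, Sequence, Tuple

def fit_bradley_terry(
    comparisons: Sequence[Tuple[int, int]],
    n_items: int,
    ridge: float = 0.01,
    lr: float = 1.0,
    steps: int = 3000,
) -> List[float]:
    """Estimate B-T strengths s_i, where P(i beats j) = 1 / (1 + exp(-(s_i - s_j))).

    `comparisons` holds (winner, loser) index pairs. The objective is the mean
    log-likelihood minus (ridge / 2) * sum(s_i ** 2), maximized by gradient ascent.
    The ridge term is an assumption that keeps estimates finite for unbeaten items.
    """
    s = [0.0] * n_items
    n = max(len(comparisons), 1)
    for _ in range(steps):
        grad = [-ridge * si for si in s]              # gradient of the ridge penalty
        for w, l in comparisons:
            p_win = 1.0 / (1.0 + math.exp(-(s[w] - s[l])))
            surprise = (1.0 - p_win) / n              # d(log p_win)/d(s_w), averaged
            grad[w] += surprise
            grad[l] -= surprise
        s = [si + lr * g for si, g in zip(s, grad)]
    return s

def standardize(scores: Sequence[float]) -> List[float]:
    """Rescale scores to mean 0 and (population) SD 1."""
    mean = sum(scores) / len(scores)
    sd = math.sqrt(sum((x - mean) ** 2 for x in scores) / len(scores))
    if sd == 0.0:
        return [0.0 for _ in scores]
    return [(x - mean) / sd for x in scores]

# Usage: datapoint 0 wins most of its comparisons, so it should rank on top.
comparisons = [(0, 1)] * 3 + [(1, 2)] * 3 + [(0, 2)] * 3 + [(1, 0), (2, 1)]
z = standardize(fit_bradley_terry(comparisons, n_items=3))
```

Because the whole model is refit from scratch, adding new comparisons only requires appending to `comparisons` and calling the fit again. That matches the "just re-estimate everything" approach described above.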
- So the new loop would look like this:
- 1. run off-the-shelf B-T over a dataset of comparisons of datapoints
- 2. extract the estimated latent variables for each datapoint
- 3. until convergence, supervised training of a D/critic NN to predict the latent for each datapoint
- 4. until convergence, RL training of a G/agent NN with the D/critic NN
- 5. extract _n_ new datapoints from the trained G/agent NN and add to the dataset
- 6. run B-T over the augmented dataset
- 7. ask the oracle for a rating of the _m_ datapoints with the largest posterior uncertainty or some proxy thereof like standard error (which will usually be the new datapoints)
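Step 3 can be sketched with a linear model standing in for the D/critic NN. The linear model is an assumption made purely to keep the example self-contained. Each training example is one input and one scalar target (the B-T latent from step 2), and the loss is plain squared error. Only a single forward pass per example is needed. `critic` and `train_regression_critic` are hypothetical names.

```python
import random
from typing import List, Sequence, Tuple

Vector = Sequence[float]

def critic(w: Vector, b: float, x: Vector) -> float:
    """Linear stand-in for the D/critic: predicted quality of a single input."""
    return sum(wj * xj for wj, xj in zip(w, x)) + b

def train_regression_critic(
    features: Sequence[Vector],
    latents: Sequence[float],
    lr: float = 0.05,
    epochs: int = 200,
    seed: int = 0,
) -> Tuple[List[float], float]:
    """Fit critic(w, b, x) to each datapoint's B-T latent by SGD on squared error."""
    rng = random.Random(seed)
    w = [0.0] * len(features[0])
    b = 0.0
    order = list(range(len(features)))
    for _ in range(epochs):
        rng.shuffle(order)
        for i in order:
            x = features[i]
            err = critic(w, b, x) - latents[i]  # gradient of 0.5 * err**2 w.r.t. the prediction
            w = [wj - lr * err * xj for wj, xj in zip(w, x)]
            b -= lr * err
    return w, b

# Usage with toy data, pretending the B-T latent happens to equal 2*x0 - x1.
grid = [-1.0, -0.5, 0.0, 0.5, 1.0]
features = [(a, c) for a in grid for c in grid]
latents = [2 * a - c for a, c in features]
w, b = train_regression_critic(features, latents)
```

Swapping in a real NN changes the model but not the shape of the problem: one input, one scalar target, and an ordinary regression loss.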
- This could have a lot of advantages:
- 1. the D/critic NN is simplified considerably: instead of 3-way classification on a double input, it is just a single input for regression
- 2. more memory-efficient: before, a double input takes up memory, even with tied weights, only to yield a single comparison; in the same space, 2 regression examples could be run, each with its own input and target quality rating
- This could be particularly useful if one tries to use a large Transformer model like GPT-2-345M where memory consumption becomes a serious barrier to running it at all... At 345M, we're down to n=1 minibatches, and I'm not sure how it would be possible to run 2 models simultaneously so as to backpropagate the classification error into both 'halves'.
- 3. more D data-efficient: many comparisons will be useless, or for a given pair, they will quickly cease to be informative; a quality rating is informative regardless of what might've been used as a comparison, providing richer feedback on each input (analogous to AlphaZero switching to a regression target)
- - possibly better 'off-policy' learning: related to saturating, a D/critic trained from a corpus (eg initializing a D/critic by taking a dataset of real and GPT-2-generated poems, and labeling all comparisons as victory for the human poem) might destroy G/agent training if it provides only comparison feedback, but
- - better value function/reward signal for any other approach leveraging the NNs (like MCTS over a tree of sequences), too
- - humans or other datasets can supply cardinal ratings directly when those are available
- 4. possibly more D training-efficient: when trained on comparisons, the D/critic must implicitly be learning an equivalent of a quality rating in order to accurately predict a human's comparison of any possible pair, but it does so in a very indirect and obfuscated fashion, which costs extra training on top of the actual difficulty of evaluation
- 5. G more data & training efficient: a richer reward signal for each sample will of course be quite useful for the G
- 6. the quality variable provides an objective measure for tracking training progress (useful for tasks which don't have one, like poetry generation!), which is also interpretable and could be useful outside of the task (eg ranking poems for recommendation or data-cleaning)
- 7. enables active learning via B-T posterior uncertainty without any need to extract uncertainty estimates of any kind from the D/critic NN; human ratings can be acquired more efficiently, or datapoints selectively pulled from a large dataset (eg imagine a huge dump of poems from Project Gutenberg or elsewhere, of wildly varying quality: with a regression-style D/critic NN, you can do a single pass over it to select the top k% of poems, insert their estimated ratings into B-T as pseudo-datapoints, and ask humans for the most informative comparisons; with a comparison D/critic NN, it is harder to see how to usefully import a large unlabeled corpus)
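This sketches the active-learning advantage above and step 7 of the loop. It assumes a crude proxy for B-T uncertainty: the standard error from the diagonal of the Fisher information of a B-T fit with a N(0,1) prior on strengths. That ignores the off-diagonal terms a full posterior would include. Datapoints with few or lopsided comparisons get large standard errors and are queued for the oracle first. A never-compared datapoint falls back to the prior SD of 1. The strengths are assumed to come from an existing B-T fit on the logit scale. Uncompared datapoints are imputed at the prior mean of 0, or at a critic estimate used as a pseudo-datapoint. `approx_standard_errors` and `most_uncertain` are hypothetical helpers.

```python
import math
from typing import List, Sequence, Tuple

def approx_standard_errors(
    strengths: Sequence[float],
    comparisons: Sequence[Tuple[int, int]],
    prior_precision: float = 1.0,
) -> List[float]:
    """Diagonal-Fisher approximation to the standard error of each B-T strength.

    Each comparison between i and j adds p * (1 - p) to the information of both,
    where p = P(i beats j); a N(0, 1 / prior_precision) prior adds prior_precision.
    """
    info = [prior_precision] * len(strengths)
    for w, l in comparisons:
        p = 1.0 / (1.0 + math.exp(-(strengths[w] - strengths[l])))
        info[w] += p * (1.0 - p)
        info[l] += p * (1.0 - p)
    return [1.0 / math.sqrt(v) for v in info]

def most_uncertain(
    strengths: Sequence[float],
    comparisons: Sequence[Tuple[int, int]],
    m: int,
) -> List[int]:
    """Indices of the m datapoints with the largest approximate standard error."""
    se = approx_standard_errors(strengths, comparisons)
    return sorted(range(len(se)), key=lambda i: se[i], reverse=True)[:m]

# Usage: 4 well-compared old datapoints plus 2 new, never-compared ones (4 and 5).
old = [(0, 1), (1, 2), (2, 3), (0, 2), (1, 3), (0, 3)] * 2
strengths = [1.5, 0.5, -0.5, -1.5, 0.0, 0.0]
to_ask = most_uncertain(strengths, old, m=2)
```

None of this touches the D/critic NN. The uncertainty comes entirely from the B-T side, which is the point of the advantage described above.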
- The main downside I can see:
- - the latent variables are not necessarily 100% stable, as the whole distribution can drift. The B-T estimates a distribution arbitrarily defined as N(0,1); if the B-T sees only very selected datapoints at the beginning, it might be that after the G/agent trains enough, the B-T step would be looking at datapoints which are much better than a mean of 0, so there might be new datapoints all the way out at (what used to be) +100SDs, say. The next cycle's B-T estimate then shifts the mean/SD to restore the conventional N(0,1). So the regression target for the D/critic's predictions of old datapoints may gradually shift over time, precisely because the richer latent variables *don't* saturate the way simple pairwise comparisons would. I believe this would be a minor problem, easily solved by retraining the D/critic NN each iteration, which is necessary just to handle novel datapoints anyway; since improvements will be small each iteration, the retraining should easily be able to keep up.
- - (frequentist) B-T might require more comparisons in order to infer any total order: a datapoint has to be compared with other datapoints which themselves have comparisons if it is to be globally ranked at all, while a comparison D/critic can work with two entirely disjoint sets of comparisons. So it might require either a few more comparisons (in order to connect all datapoints via a chain of comparisons), or switching to ad hoc imputation like imputing a mean of 0 for uncompared datapoints, or fully Bayesian B-T (which uses priors to provide meaningful estimates for all datapoints). This seems like probably a minor drawback: at worst a Bayesian B-T would take a few seconds to run, which is bad for interactive use but minor compared to the slowness of DRL in general.
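This sketches the connectivity check implied by the frequentist caveat. It uses a union-find over the comparison graph, which is a standard technique rather than something the B-T step itself provides. Datapoints in different connected components have no chain of comparisons between them, so their relative positions on the latent scale are not identified. One bridging comparison per extra component is the minimum needed to link everything. Finite maximum-likelihood estimates need slightly more than connectivity: for every split of the datapoints into two groups, each group must have beaten the other at least once. `components` and `bridging_pairs` are hypothetical names. Linking each component to the first member of the largest component is an arbitrary choice.

```python
from typing import Dict, List, Sequence, Tuple

def components(n_items: int, comparisons: Sequence[Tuple[int, int]]) -> List[List[int]]:
    """Group datapoints into connected components of the (undirected) comparison graph."""
    parent = list(range(n_items))

    def find(i: int) -> int:
        while parent[i] != i:
            parent[i] = parent[parent[i]]  # path halving keeps the trees shallow
            i = parent[i]
        return i

    for a, b in comparisons:
        ra, rb = find(a), find(b)
        if ra != rb:
            parent[ra] = rb
    groups: Dict[int, List[int]] = {}
    for i in range(n_items):
        groups.setdefault(find(i), []).append(i)
    return list(groups.values())

def bridging_pairs(n_items: int, comparisons: Sequence[Tuple[int, int]]) -> List[Tuple[int, int]]:
    """Pairs to show the oracle so every component links to the largest one.

    The tuple order is just the pair shown, not a predicted winner.
    """
    comps = sorted(components(n_items, comparisons), key=len, reverse=True)
    anchor = comps[0][0]
    return [(anchor, comp[0]) for comp in comps[1:]]

# Usage: datapoints {0, 1, 2} and {3, 4} were never compared across groups; 5 never at all.
comps_seen = [(0, 1), (1, 2), (3, 4)]
bridges = bridging_pairs(6, comps_seen)
```

Once the oracle answers the bridging pairs, the comparison graph is connected. The win-in-both-directions condition above may still call for a few extra comparisons before an unregularized fit is finite.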
- All in all, I think this version of preference learning could be simpler, easier to implement, *and* train faster.