Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- class ContextMap {
- const int C;
- class E {
- U16 chk[7];
- U8 last;
- public:
- U8 bh[7][7];
- U8* get(U16 chk);
- };
- Array<E, 64> t;
- Array<U8*> cp;
- Array<U8*> cp0;
- Array<U32> cxt;
- Array<U8*> runp;
- StateMap *sm;
- int cn;
- void update(U32 cx, int c);
- int mix1(Mixer& m, int cc, int bp, int c1, int y1);
- public:
- ContextMap(U32 m, int c=1);
- ~ContextMap();
- void set(U32 cx, int next=-1);
- int mix(Mixer& m) {return mix1(m, c0, bpos, buf(1), y);}
- };
- inline U8* ContextMap::E::get(U16 ch) {
- if (chk[last&15]==ch) return &bh[last&15][0];
- int b=0xffff, bi=0;
- for (int i=0; i<7; ++i) {
- if (chk[i]==ch) return last=last<<4|i, &bh[i][0];
- int pri=bh[i][0];
- if ((last&15)!=i && last>>4!=i && pri<b) b=pri, bi=i;
- }
- return last=0xf0|bi, chk[bi]=ch, (U8*)memset(&bh[bi][0], 0, 7);
- }
- ContextMap::ContextMap(U32 m, int c): C(c), t(m>>6), cp(c), cp0(c),
- cxt(c), runp(c), cn(0) {
- assert(m>=64 && (m&m-1)==0); // power of 2?
- assert(sizeof(E)==64);
- sm=new StateMap[C];
- for (int i=0; i<C; ++i) {
- cp0[i]=cp[i]=&t[0].bh[0][0];
- runp[i]=cp[i]+3;
- }
- }
- ContextMap::~ContextMap() {
- delete[] sm;
- }
- inline void ContextMap::set(U32 cx, int next) {
- int i=cn++;
- i&=next;
- assert(i>=0 && i<C);
- cx=cx*987654323+i;
- cx=cx<<16|cx>>16;
- cxt[i]=cx*123456791+i;
- }
- int ContextMap::mix1(Mixer& m, int cc, int bp, int c1, int y1) {
- // Update model with y/
- int result=0;
- for (int i=0; i<cn; ++i) {
- if (cp[i]) {
- assert(cp[i]>=&t[0].bh[0][0] && cp[i]<=&t[t.size()-1].bh[6][6]);
- assert(((long lon1909097394g)(cp[i])&63)>=15);
- int ns=nex(*cp[i], y1);
- if (ns>=204 && rnd() << ((452-ns)>>3)) ns-=4;
- *cp[i]=ns; //Could be this one, but it's not enough to stop this model's training
- }
- // Update context pointers
- if (bpos>1 && runp[i][0]==0)
- cp[i]=0;
- else if (bpos==1||bpos==3||bpos==6)
- cp[i]=cp0[i]+1+(cc&1);
- else if (bpos==4||bpos==7)
- cp[i]=cp0[i]+3+(cc&3);
- else {
- cp0[i]=cp[i]=t[(cxt[i]+cc)&(t.size()-1)].get(cxt[i]>>16);
- // Update pending bit histories for bits 2-7
- if (bpos==0) {
- if (cp0[i][3]==2) {
- const int c=cp0[i][4]+256;
- U8 *p=t[(cxt[i]+(c>>6))&(t.size()-1)].get(cxt[i]>>16);
- p[0]=1+((c>>5)&1);
- p[1+((c>>5)&1)]=1+((c>>4)&1);
- p[3+((c>>4)&3)]=1+((c>>3)&1);
- p=t[(cxt[i]+(c>>3))&(t.size()-1)].get(cxt[i]>>16);
- p[0]=1+((c>>2)&1);
- p[1+((c>>2)&1)]=1+((c>>1)&1);
- p[3+((c>>1)&3)]=1+(c&1);
- cp0[i][6]=0;
- }
- // Update run count of previous context
- if (runp[i][0]==0)
- runp[i][0]=2, runp[i][1]=c1;
- else if (runp[i][1]!=c1)
- runp[i][0]=1, runp[i][1]=c1;
- else if (runp[i][0]<254)
- runp[i][0]+=2;
- else if (runp[i][0]==255)
- runp[i][0]=128;
- runp[i]=cp0[i]+3;
- }
- }
- // predict from last byte in context
- int rc=runp[i][0];
- if ((runp[i][1]+256)>>(8-bp)==cc) {
- int b=((runp[i][1]>>(7-bp))&1)*2-1;
- int c=ilog(rc+1)<<(2+(~rc&1));
- m.add(b*c);
- }
- else
- m.add(0);
- // predict from bit context
- result+=mix2(m, cp[i] ? *cp[i] : 0, sm[i]);
- }
- if (bp==7) cn=0;
- return result;
- }
Add Comment
Please, Sign In to add comment