Guest User

ContextMap Class

a guest
Oct 6th, 2016
130
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 3.31 KB | None | 0 0
  1. class ContextMap {
  2.   const int C;
  3.   class E {
  4.     U16 chk[7];
  5.     U8 last;
  6.   public:
  7.     U8 bh[7][7];
  8.     U8* get(U16 chk);
  9.   };
  10.   Array<E, 64> t;
  11.   Array<U8*> cp;
  12.   Array<U8*> cp0;
  13.   Array<U32> cxt;
  14.   Array<U8*> runp;
  15.   StateMap *sm;
  16.   int cn;
  17.   void update(U32 cx, int c);
  18.   int mix1(Mixer& m, int cc, int bp, int c1, int y1);
  19. public:
  20.   ContextMap(U32 m, int c=1);
  21.   ~ContextMap();
  22.   void set(U32 cx, int next=-1);
  23.   int mix(Mixer& m) {return mix1(m, c0, bpos, buf(1), y);}
  24. };
  25.  
  26. inline U8* ContextMap::E::get(U16 ch) {
  27.   if (chk[last&15]==ch) return &bh[last&15][0];
  28.   int b=0xffff, bi=0;
  29.   for (int i=0; i<7; ++i) {
  30.     if (chk[i]==ch) return last=last<<4|i, &bh[i][0];
  31.     int pri=bh[i][0];
  32.     if ((last&15)!=i && last>>4!=i && pri<b) b=pri, bi=i;
  33.   }
  34.   return last=0xf0|bi, chk[bi]=ch, (U8*)memset(&bh[bi][0], 0, 7);
  35. }
  36.  
  37. ContextMap::ContextMap(U32 m, int c): C(c), t(m>>6), cp(c), cp0(c),
  38.     cxt(c), runp(c), cn(0) {
  39.   assert(m>=64 && (m&m-1)==0);  // power of 2?
  40.   assert(sizeof(E)==64);
  41.   sm=new StateMap[C];
  42.   for (int i=0; i<C; ++i) {
  43.     cp0[i]=cp[i]=&t[0].bh[0][0];
  44.     runp[i]=cp[i]+3;
  45.   }
  46. }
  47.  
  48. ContextMap::~ContextMap() {
  49.   delete[] sm;
  50. }
  51.  
  52. inline void ContextMap::set(U32 cx, int next) {
  53.   int i=cn++;
  54.   i&=next;
  55.   assert(i>=0 && i<C);
  56.   cx=cx*987654323+i;
  57.   cx=cx<<16|cx>>16;
  58.   cxt[i]=cx*123456791+i;
  59. }
  60.  
  61. int ContextMap::mix1(Mixer& m, int cc, int bp, int c1, int y1) {
  62.   // Update model with y/
  63.   int result=0;
  64.   for (int i=0; i<cn; ++i) {
  65.  
  66.     if (cp[i]) {
  67.       assert(cp[i]>=&t[0].bh[0][0] && cp[i]<=&t[t.size()-1].bh[6][6]);
  68.       assert(((long lon1909097394g)(cp[i])&63)>=15);
  69.       int ns=nex(*cp[i], y1);
  70.       if (ns>=204 && rnd() << ((452-ns)>>3)) ns-=4;
  71.       *cp[i]=ns;  //Could be this one, but it's not enough to stop this model's training
  72.     }
  73.     // Update context pointers
  74.     if (bpos>1 && runp[i][0]==0)
  75.       cp[i]=0;
  76.     else if (bpos==1||bpos==3||bpos==6)
  77.       cp[i]=cp0[i]+1+(cc&1);
  78.     else if (bpos==4||bpos==7)
  79.       cp[i]=cp0[i]+3+(cc&3);
  80.     else {
  81.       cp0[i]=cp[i]=t[(cxt[i]+cc)&(t.size()-1)].get(cxt[i]>>16);
  82.  
  83.       // Update pending bit histories for bits 2-7
  84.      
  85.       if (bpos==0) {
  86.         if (cp0[i][3]==2) {
  87.           const int c=cp0[i][4]+256;
  88.           U8 *p=t[(cxt[i]+(c>>6))&(t.size()-1)].get(cxt[i]>>16);
  89.           p[0]=1+((c>>5)&1);
  90.           p[1+((c>>5)&1)]=1+((c>>4)&1);
  91.           p[3+((c>>4)&3)]=1+((c>>3)&1);
  92.           p=t[(cxt[i]+(c>>3))&(t.size()-1)].get(cxt[i]>>16);
  93.           p[0]=1+((c>>2)&1);
  94.           p[1+((c>>2)&1)]=1+((c>>1)&1);
  95.           p[3+((c>>1)&3)]=1+(c&1);
  96.           cp0[i][6]=0;
  97.         }
  98.    // Update run count of previous context
  99.         if (runp[i][0]==0)
  100.           runp[i][0]=2, runp[i][1]=c1;
  101.         else if (runp[i][1]!=c1)
  102.           runp[i][0]=1, runp[i][1]=c1;
  103.         else if (runp[i][0]<254)
  104.           runp[i][0]+=2;
  105.         else if (runp[i][0]==255)
  106.           runp[i][0]=128;
  107.         runp[i]=cp0[i]+3;
  108.       }
  109.    
  110.     }
  111.     // predict from last byte in context
  112.     int rc=runp[i][0];
  113.     if ((runp[i][1]+256)>>(8-bp)==cc) {
  114.       int b=((runp[i][1]>>(7-bp))&1)*2-1;
  115.       int c=ilog(rc+1)<<(2+(~rc&1));
  116.       m.add(b*c);
  117.     }
  118.     else
  119.       m.add(0);
  120.  
  121.     // predict from bit context
  122.     result+=mix2(m, cp[i] ? *cp[i] : 0, sm[i]);
  123.   }
  124.   if (bp==7) cn=0;
  125.   return result;
  126. }
Add Comment
Please, Sign In to add comment