/* Posted by 아메숑, 2019. 4. 1. 09:12 */
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
/* Extract bit n of W (bit 0 = least significant) as 0 or 1. */
#define GetBit(W,n) (((W)>>(n)) & 0x01)
/* Extract bit n of W and place it at bit position k. */
#define ReturnBit(W,n,k) (GetBit(W,n) << (k))

/* Prototype so firstbit() may be called by functions defined before it. */
int firstbit(int f);

/* Reduction polynomial x^4 + x + 1 (binary 10011) for GF(2^4) arithmetic.
 * Written in hex because 0b... literals are not standard C before C23. */
int m = 0x13;

/* 4-bit substitution table, indexed as [top two bits][bottom two bits]
 * of the input nibble (see NS / NS2). */
int Sbox[4][4] = { {9, 4, 10, 11},
                {13, 1, 8, 5},
                {6, 2, 0, 3},
                {12, 14, 15, 7} };

/* Inverse of Sbox, indexed the same way (see INS). */
int ISbox[4][4] = { { 10,5,9,11 },
                { 1,7,8,15},
                { 6,0,2,3 },
                { 12,4,13,14 } };

/*
 * Multiply two elements of GF(2^4) reduced by the polynomial m.
 *
 * a is repeatedly doubled (shift left by one; when bit 4 appears the value
 * is reduced by XOR with m). The doublings a*x^i whose bit i is set in b are
 * XOR-ed together. Only bits 0..3 of b are consulted.
 *
 * Returns the product as a 4-bit value when a is a 4-bit value.
 */
int multiply(int a, int b)
{
    int sum = 0;
    int term = a;           /* a * x^i for the current i */
    int i;

    for (i = 0; i < 4; i++)
    {
        if (GetBit(b, i) == 1)
            sum ^= term;

        term <<= 1;
        /* Bit 4 set means a degree-4 term appeared: reduce modulo m.
         * Tested directly so no forward call to firstbit() is needed. */
        if (term & 0x10)
            term ^= m;
    }
    return sum;
}

/*
 * Nibble substitution over a 16-bit state.
 * Each of the four nibbles of w1 is replaced by Sbox[top two bits][bottom
 * two bits]; the results keep their original nibble positions.
 */
int NS(int w1)
{
    int out = 0;
    int shift;

    for (shift = 12; shift >= 0; shift -= 4)
    {
        int nib = (w1 >> shift) & 0xf;
        out |= Sbox[nib >> 2][nib & 0x3] << shift;
    }
    return out;
}

/*
 * Inverse nibble substitution over a 16-bit state.
 * Processes the four nibbles of w1 from most to least significant, mapping
 * each through ISbox[top two bits][bottom two bits] and packing the results
 * back in the same order.
 */
int INS(int w1)
{
    int acc = 0;
    int k;

    for (k = 3; k >= 0; k--)
    {
        int nib = (w1 >> (4 * k)) & 0xf;
        acc = (acc << 4) | ISbox[(nib >> 2) & 0x3][nib & 0x3];
    }
    return acc;
}

/*
 * Nibble substitution over the low byte of w1 (used by key expansion).
 * Both nibbles go through Sbox[top two bits][bottom two bits].
 */
int NS2(int w1)
{
    int hi = (w1 >> 4) & 0xf;
    int lo = w1 & 0xf;

    return (Sbox[hi >> 2][hi & 0x3] << 4) | Sbox[lo >> 2][lo & 0x3];
}

/*
 * MixColumns on a 16-bit state viewed as two columns of two nibbles:
 * column 0 = (bits 15..12, bits 11..8), column 1 = (bits 7..4, bits 3..0).
 * Each column (top, bottom) becomes (top ^ 4*bottom, 4*top ^ bottom),
 * with products computed in GF(2^4) by multiply().
 */
int MC(int a)
{
    int top0 = a >> 12;
    int bot0 = (a >> 8) & 0xf;
    int top1 = (a >> 4) & 0xf;
    int bot1 = a & 0xf;

    int col0 = ((top0 ^ multiply(4, bot0)) << 12)
             | ((multiply(4, top0) ^ bot0) << 8);
    int col1 = ((top1 ^ multiply(4, bot1)) << 4)
             |  (multiply(4, top1) ^ bot1);

    return col0 | col1;
}

/*
 * Inverse MixColumns on the same two-column nibble layout as MC.
 * Each column (top, bottom) becomes (9*top ^ 2*bottom, 2*top ^ 9*bottom),
 * with products computed in GF(2^4) by multiply().
 */
int IMC(int a)
{
    int top0 = a >> 12;
    int bot0 = (a >> 8) & 0xf;
    int top1 = (a >> 4) & 0xf;
    int bot1 = a & 0xf;

    int col0 = ((multiply(9, top0) ^ multiply(2, bot0)) << 12)
             | ((multiply(2, top0) ^ multiply(9, bot0)) << 8);
    int col1 = ((multiply(9, top1) ^ multiply(2, bot1)) << 4)
             |  (multiply(2, top1) ^ multiply(9, bot1));

    return col0 | col1;
}

/*
 * ShiftRows: swaps nibble 0 (bits 3..0) with nibble 2 (bits 11..8) and
 * leaves every other bit untouched. Applying it twice is the identity,
 * so it also serves as its own inverse.
 */
int SR(int a)
{
    int kept = a & ~0x0F0F;           /* bits not involved in the swap */
    int moved_up = (a & 0xf) << 8;    /* nibble 0 -> nibble 2 */
    int moved_down = (a >> 8) & 0xf;  /* nibble 2 -> nibble 0 */

    return kept | moved_up | moved_down;
}

/*
 * Return the 1-based position of the most significant set bit of f, or 0
 * when f == 0. Read as a GF(2) polynomial, this is deg(f) + 1; for example,
 * firstbit(0x13) == 5.
 *
 * Every bit of f is examined (as an unsigned bit pattern), so values of 32
 * and above get their true position rather than one limited to bits 0..4.
 */
int firstbit(int f)
{
    unsigned int v = (unsigned int)f;
    int pos = 0;

    while (v != 0)
    {
        pos++;
        v >>= 1;
    }
    return pos;
}

/*
 * Remainder of GF(2) polynomial division: a mod b, where bit i of each
 * argument is the coefficient of x^i. Degrees are measured with firstbit().
 *
 * Returns the remainder. When b == 0 (division by the zero polynomial) a is
 * returned unchanged rather than looping forever.
 */
int mod(int a, int b)
{
    int db;

    if (b == 0)
        return a;

    db = firstbit(b);
    while (firstbit(a) >= db)
    {
        /* Cancel a's leading term by XOR-ing b aligned to the same degree.
         * An integer shift replaces the former floating-point pow(2, mv). */
        a ^= b << (firstbit(a) - db);
    }
    return a;
}

/*
 * One key-expansion step.
 *
 * key: current 16-bit round key, split into w0 (high byte) and w1 (low byte).
 * i:   round number; the round constant is x^(i+2) reduced modulo m, placed
 *      in the high nibble of a byte.
 *
 * Returns the next round key (w2 << 8) | w3, where
 *   w2 = w0 ^ RC ^ NS2(nibble-swapped w1)
 *   w3 = w2 ^ w1
 */
int KE(int key, int i)
{
    int w0 = (key >> 8) & 0xff;
    int w1 = key & 0xff;

    /* Swap the two nibbles of w1, then substitute each through the S-box. */
    int t = NS2((w1 >> 4) | ((w1 & 0xf) << 4));

    /* Round constant: 4 is x^2, so 4 << i is x^(i+2). */
    int rc = 4 << i;
    if (firstbit(rc) > 4)       /* degree >= 4: reduce into GF(2^4) */
        rc = mod(rc, m);
    rc <<= 4;

    int w2 = t ^ rc ^ w0;
    /* Use the whole low byte (0xff); masking with 127 would drop bit 7. */
    int w3 = w2 ^ w1;

    return (w2 << 8) | w3;
}

/*
 * Demo driver: derives two round keys from a fixed 16-bit key, encrypts a
 * fixed 16-bit block while printing every intermediate state, then runs the
 * inverse steps in reverse order. Because every decryption step inverts the
 * matching encryption step, the final "AK" line should print the original
 * plaintext.
 */
int main(void)
{
    int plain = 0x6F6B;
    int key = 0xA73B;
    int key1 = KE(key, 1);
    printf("\nKEY1 : %x", key1);
    int key2 = KE(key1, 2);
    printf("\nKEY2 : %x", key2);

    /* ---- Encryption ---- */
    plain ^= key; //AK
    printf("\nAK : %x", plain);
    plain = NS(plain);
    printf("\nNS : %x", plain);
    plain = SR(plain); //SR
    printf("\nSR : %x", plain);
    plain = MC(plain);
    printf("\nMC : %x", plain);
    plain ^= key1; //AK
    printf("\nAK : %x", plain);
    plain = NS(plain);
    printf("\nNS : %x", plain);
    plain = SR(plain); //SR
    printf("\nSR : %x", plain);
    plain ^= key2; //AK -> ciphertext
    printf("\nAK : %x", plain);

    /* ---- Decryption: inverse steps in reverse order ---- */
    plain ^= key2; //AK
    printf("\nAK : %x", plain);
    plain = SR(plain); //SR is its own inverse
    printf("\nSR : %x", plain);
    plain = INS(plain);
    printf("\nINS : %x", plain);
    plain ^= key1; //AK
    printf("\nAK : %x", plain);
    plain = IMC(plain);
    printf("\nIMC : %x", plain);
    plain = SR(plain); //SR
    printf("\nSR : %x", plain);
    plain = INS(plain);
    printf("\nINS : %x", plain);
    plain ^= key; //AK -> recovered plaintext
    printf("\nAK : %x", plain);

    /* Terminate the last output line. */
    printf("\n");
    return 0;
}