x86/x64 SIMD命令によるAES暗号処理

x86/x64のAESNI命令群の使い方を解説します。

基本

AES暗号については、米国の政府機関が出しているFIPS 197という文書がオリジナルの仕様になりますのでそちらを見ながら読んでください。

ブロックサイズ、キーサイズ

AES暗号ではキーサイズは128ビット、192ビット、256ビットの3種類があり、AES-128、AES-192、AES-256と呼ばれます。

各キーサイズごとに固有の定数Nk(キー長)、Nb(ブロック長)、Nr(ラウンド回数)があり下のように決まっています。NkとNbの単位はwordで、AESでは「word」は32ビットです。


(FIPS 197より抜粋)

煩わしいので以下Nbは4と書きます。

AES暗号ではどのキー長でもブロック長は4ワード(16バイト)固定です。暗号化では16バイトの平文を入力して16バイトの暗号文を得ます。復号では16バイトの暗号文を入力して16バイトの平文を得ます。

state配列

state配列は4*4バイトの二次元配列で、ここに平文を入れたあと、roundと呼ばれる処理を繰り返して加工することで暗号文を作ります。roundはNr回繰り返します。


(FIPS 197より抜粋)

AESNIでは、XMMWORDひとつにこの配列を丸ごと入れて処理します。

AESNI命令の使用の際には、バイトオーダー(エンディアン)の変換を行う必要はありません。

w配列

w配列は、Nkワードのキーをもとに、あらかじめKeyExpansionという処理をして作っておく4 * (Nr + 1)ワードの一次元配列です。round 1回ごとに4ワードずつ使っていきます。roundの周回に入る前に下処理でひとつ使うのでNr+1回分が必要です。

AESNIでは、XMMWORDひとつにround 1回分(4ワード)の要素を丸ごと入れて処理します(AESのワードは32ビットです)。

 

暗号化

以下はFIPS 197に書いてある暗号化のアルゴリズムです。色は解説の都合上私がつけたものです。AESNI命令ではの部分を行います。のAddRoundKey()はxorしているだけなのでpxor命令でできます。

AESENC - AES ENCrypt
AESENCLAST -AES ENCrypt LAST

AESENC xmm1, xmm2/m128   (AESNI
__m128i  _mm_aesenc_si128(__m128i state, __m128i w);
AESENCLAST xmm1, xmm2/m128   (AESNI
__m128i  _mm_aesenclast_si128(__m128i state, __m128i w);

VAESENC xmm1, xmm2, xmm3/m128   (AESNI+(V1
VAESENCLAST xmm1, xmm2, xmm3/m128   (AESNI+(V1

VAESENC ymm1, ymm2, ymm3/m256   (VAES
__m256i  _mm256_aesenc_epi128(__m256i state, __m256i w);
VAESENCLAST ymm1, ymm2, ymm3/m256   (VAES
__m256i  _mm256_aesenclast_epi128(__m256i state, __m256i w);

VAESENC zmm1, zmm2, zmm3/m512   (VAES+(V5
__m512i  _mm512_aesenc_epi128(__m512i state, __m512i w);
VAESENCLAST zmm1, zmm2, zmm3/m512   (VAES+(V5
__m512i  _mm512_aesenclast_epi128(__m512i state, __m512i w);

AESENCは、の4つの関数を呼び出している部分(1ラウンド分)を1命令で処理します。

は最終ラウンドですが、この回だけ呼び出す関数がひとつ少ないので、AESENCLAST命令で処理します。

①にラウンド前のstate配列、②にそのラウンドで使うw配列の要素を入れて実行すると、③にラウンド後のstate配列が返ります。

命令の中で具体的に何をやっているかはFIPS 197の各関数の説明をご覧ください。

w配列(round用キー)を生成する

暗号化の前にKey Expansionで暗号キー(Nkワードのデータ)からw配列を作ります。

以下はFIPS 197のKeyExpansionのアルゴリズムです。

Rcon配列はFIPS 197で定義された方法で計算された定数の配列で、具体的な値は以下のようになります。AES暗号で実際に使われるのは[1]~[10]の範囲だけです。

static const BYTE Rcon[] = {
 0x8d, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36, 0x6c, 0xd8, 0xab, 0x4d, 0x9a,
};

keyをw配列にコピーする際に、バイトオーダー(エンディアン)の変換を行う必要はありません。

AESKEYGENASSIST - AES KEY GENeration ASSIST

AESKEYGENASSIST xmm1, xmm2/m128, imm8   (AESNI
VAESKEYGENASSIST xmm1, xmm2/m128, imm8   (AESNI+(V1
__m128i  _mm_aeskeygenassist_si128(__m128i temp, const int Rcon);

AESKEYGENASSISTは、①にtempの値を入れて、②にRcon[i/Nk]の値を指定して実行すると、の2つの式の値を計算して③に返してくれます。

復号

FIPS 197では、復号のアルゴリズムとして、暗号化の手順を逆にしただけの「InvCipher」と、それと等価な結果を得られる「EqInvCipher」という2つが定義されていますが、AESNIがサポートするのは後者のみです。

EqInvCipherでは、w配列の生成処理に後処理をちょっと追加して作ったdw配列を使います。dw配列の作り方は後述します。

AESDEC - AES DECrypt
AESDECLAST -AES DECrypt LAST

AESDEC xmm1, xmm2/m128   (AESNI
__m128i  _mm_aesdec_si128(__m128i state, __m128i dw);
AESDECLAST xmm1, xmm2/m128   (AESNI
__m128i  _mm_aesdeclast_si128(__m128i state, __m128i dw);

VAESDEC xmm1, xmm2, xmm3/m128   (AESNI+(V1
VAESDECLAST xmm1, xmm2, xmm3/m128   (AESNI+(V1

VAESDEC ymm1, ymm2, ymm3/m256   (VAES
__m256i  _mm256_aesdec_epi128(__m256i state, __m256i dw);
VAESDECLAST ymm1, ymm2, ymm3/m256   (VAES
__m256i  _mm256_aesdeclast_epi128(__m256i state, __m256i dw);

VAESDEC zmm1, zmm2, zmm3/m512   (VAES+(V5
__m512i  _mm512_aesdec_epi128(__m512i state, __m512i dw);
VAESDECLAST zmm1, zmm2, zmm3/m512   (VAES+(V5
__m512i  _mm512_aesdeclast_epi128(__m512i state, __m512i dw);

AESDECは、の4つの関数を呼び出している部分(1ラウンド分)を1命令で処理します。

は最終ラウンドですが、この回だけ呼び出す関数がひとつ少ないので、AESDECLAST命令で処理します。

①にラウンド前のstate配列、②にそのラウンドで使うdw配列の要素を入れて実行すると、③にラウンド後のstate配列が返ります。

dw配列を生成する

EqInvCipher用のdw配列は、前述のKeyExpansion処理で作ったw配列に、以下の後処理をすることで作ります。

AESIMC - AES InvMixColumns

AESIMC xmm1, xmm2/m128   (AESNI
VAESIMC xmm1, xmm2/m128   (AESNI+(V1
__m128i  _mm_aesimc_si128(__m128i w);

上のInvMixColumns()の部分を処理する命令です。w配列の要素を①に入れて実行すると、後処理後のdw配列の要素が得られます。

サンプル

AES-128   AES-192   AES-256

#pragma once

#include <intrin.h>

class AES128_NI
{
public:
    // AES128
    static const int Nk = 4;
    static const int Nb = 4;
    static const int Nr = 10;

protected:
    __m128i w128[Nr + 1];
    __m128i dw128[Nr + 1];
    bool decrypting;

public:
    AES128_NI(const unsigned char key[4 * Nk], bool decrypting)
        : decrypting(decrypting)
    {
        KeyExpansion(key);
    }

    void Cipher(const unsigned char in[4 * Nb], unsigned char out[4 * Nb]);
    void EqInvCipher(const unsigned char in[4 * Nb], unsigned char out[4 * Nb]);

protected:
    void KeyExpansion(const unsigned char key[4 * Nk]);
};

#include "AES128_NI.h"

//Cipher(byte in[4*Nb], byte out[4*Nb], word w[Nb*(Nr+1)])
//begin
void AES128_NI::Cipher(const unsigned char in[4 * Nb], unsigned char out[4 * Nb])
{
    //  ASSERT(!decrypting);

    //state = in
    __m128i state = _mm_loadu_si128(reinterpret_cast<const __m128i *>(in));

    //AddRoundKey(state, w[0, Nb-1])
    // just XOR
    state = _mm_xor_si128(state, w128[0]);

    //for round = 1 step 1 to Nr-1
        //SubBytes(state)
        //ShiftRows(state)
        //MixColumns(state)
        //AddRoundKey(state, w[round*Nb, (round+1)*Nb-1])
    //end for
    state = _mm_aesenc_si128(state, w128[1]);
    state = _mm_aesenc_si128(state, w128[2]);
    state = _mm_aesenc_si128(state, w128[3]);
    state = _mm_aesenc_si128(state, w128[4]);
    state = _mm_aesenc_si128(state, w128[5]);
    state = _mm_aesenc_si128(state, w128[6]);
    state = _mm_aesenc_si128(state, w128[7]);
    state = _mm_aesenc_si128(state, w128[8]);
    state = _mm_aesenc_si128(state, w128[9]);

    // The last round
    //SubBytes(state)
    //ShiftRows(state)
    //AddRoundKey(state, w[Nr*Nb, (Nr+1)*Nb-1])
    state = _mm_aesenclast_si128(state, w128[Nr]);

    //out = state
    _mm_storeu_si128(reinterpret_cast<__m128i *>(out), state);

    //end
}

//EqInvCipher(byte in[4*Nb], byte out[4*Nb], word dw[Nb*(Nr+1)])
//begin
void AES128_NI::EqInvCipher(const unsigned char in[4 * Nb], unsigned char out[4 * Nb])
{
    //  ASSERT(decrypting);

    //byte  state[4,Nb]
    //state = in
    __m128i state = _mm_loadu_si128(reinterpret_cast<const __m128i *>(in));

    //AddRoundKey(state, dw[Nr*Nb, (Nr+1)*Nb-1])
    state = _mm_xor_si128(state, w128[Nr]);

    //for round = Nr-1 step -1 downto 1
        //InvSubBytes(state)
        //InvShiftRows(state)
        //InvMixColumns(state)
        //AddRoundKey(state, dw[round*Nb, (round+1)*Nb-1])
    //end for
    state = _mm_aesdec_si128(state, dw128[9]);
    state = _mm_aesdec_si128(state, dw128[8]);
    state = _mm_aesdec_si128(state, dw128[7]);
    state = _mm_aesdec_si128(state, dw128[6]);
    state = _mm_aesdec_si128(state, dw128[5]);
    state = _mm_aesdec_si128(state, dw128[4]);
    state = _mm_aesdec_si128(state, dw128[3]);
    state = _mm_aesdec_si128(state, dw128[2]);
    state = _mm_aesdec_si128(state, dw128[1]);

    //InvSubBytes(state)
    //InvShiftRows(state)
    //AddRoundKey(state, dw[0, Nb-1])
    state = _mm_aesdeclast_si128(state, dw128[0]);

    //out = state
    _mm_storeu_si128(reinterpret_cast<__m128i *>(out), state);
    //end
}

//static const BYTE Rcon[] = {
//  0x8d, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36, 0x6c, 0xd8, 0xab, 0x4d, 0x9a,
//};

//KeyExpansion(byte key[4*Nk], word w[Nb*(Nr+1)], Nk)
//begin
void AES128_NI::KeyExpansion(const unsigned char key[4 * Nk])
{
    //word  temp
    //i = 0

    //while (i < Nk)
        //w[i] = word(key[4*i], key[4*i+1], key[4*i+2], key[4*i+3])
        //i = i+1
    //end while
    __m128i work = _mm_loadu_si128(reinterpret_cast<const __m128i*>(key));  // work = w3 : w2 : w1 : w0
    w128[0] = work;

    //i = Nk
    //while (i < Nb * (Nr+1)]
        //temp = w[i-1]
        //if (i mod Nk = 0)
            //temp = SubWord(RotWord(temp)) xor Rcon[i/Nk]
        //else if (Nk > 6 and i mod Nk = 4)
            //temp = SubWord(temp)
        //end if
        //w[i] = w[i-Nk] xor temp
        //i = i + 1
    //end while


    __m128i t, t2;

    // f(w) = SubWord(RotWord(w)) xor Rcon

    // work = w3 : w2 : w1 : w0
    // w4 = w0^f(w3)
    // w5 = w1^w4 = w1^w0^f(w3)
    // w6 = w2^w5 = w2^w1^w0^f(w3)
    // w7 = w3^w6 = w3^w2^w1^w0^f(w3)
#define EXPAND(n, RCON) \
    t  = _mm_slli_si128(work, 4);       /* t    = w2                : w1             : w0          : 0        */\
    t  = _mm_xor_si128(work, t);        /* t    = w3^w2             : w2^w1          : w1^w0       : w0       */\
    t2 = _mm_slli_si128(t, 8);          /* t2   = w1^w0             : w0             : 0           : 0        */\
    t  = _mm_xor_si128(t, t2);          /* t    = w3^w2^w1^w0       : w2^w1^w0       : w1^w0       : w0       */\
    work = _mm_aeskeygenassist_si128(work, RCON);           /* work = f(w3) : - : - : - */                      \
    work = _mm_shuffle_epi32(work, 0xFF);/*work = f(w3)             : f(w3)          : f(w3)       ; f(w3)    */\
    work = _mm_xor_si128(t, work);      /* work = w3^w2^w1^w0^f(w3) : w2^w1^w0^f(w3) : w1^w0^f(w3) : w0^f(w3) */\
    w128[n] = work;                     /* work = w7 : w6 : w5 : w4 */

    EXPAND(1, 0x01);

    // Go on...
    EXPAND(2, 0x02);
    EXPAND(3, 0x04);
    EXPAND(4, 0x08);
    EXPAND(5, 0x10);
    EXPAND(6, 0x20);
    EXPAND(7, 0x40);
    EXPAND(8, 0x80);
    EXPAND(9, 0x1b);
    EXPAND(10, 0x36);

    // Additional process for EqInvCipher
    if (decrypting) {
        //for i = 0 step 1 to (Nr+1)*Nb-1
            //dw[i] = w[i]
        //end for
        dw128[0] = w128[0];

        //for round = 1 step 1 to Nr-1
            //InvMixColumns(dw[round*Nb, (round+1)*Nb-1])
        //end for
        dw128[1] = _mm_aesimc_si128(w128[1]);
        dw128[2] = _mm_aesimc_si128(w128[2]);
        dw128[3] = _mm_aesimc_si128(w128[3]);
        dw128[4] = _mm_aesimc_si128(w128[4]);
        dw128[5] = _mm_aesimc_si128(w128[5]);
        dw128[6] = _mm_aesimc_si128(w128[6]);
        dw128[7] = _mm_aesimc_si128(w128[7]);
        dw128[8] = _mm_aesimc_si128(w128[8]);
        dw128[9] = _mm_aesimc_si128(w128[9]);

        dw128[Nr] = w128[Nr];
    }
    //end
}
#pragma once

#include <intrin.h>

class AES192_NI
{
public:
    // AES192
    static const int Nk = 6;
    static const int Nb = 4;
    static const int Nr = 12;

protected:
    __m128i w128[Nr + 1];
    __m128i dw128[Nr + 1];
    bool decrypting;

public:
    AES192_NI(const unsigned char key[4 * Nk], bool decrypting)
        : decrypting(decrypting)
    {
        KeyExpansion(key);
    }

    void Cipher(const unsigned char in[4 * Nb], unsigned char out[4 * Nb]);
    void EqInvCipher(const unsigned char in[4 * Nb], unsigned char out[4 * Nb]);

protected:
    void KeyExpansion(const unsigned char key[4 * Nk]);
};

#include "AES192_NI.h"

//Cipher(byte in[4*Nb], byte out[4*Nb], word w[Nb*(Nr+1)])
//begin
void AES192_NI::Cipher(const unsigned char in[4 * Nb], unsigned char out[4 * Nb])
{
    //  ASSERT(!decrypting);

    //state = in
    __m128i state = _mm_loadu_si128(reinterpret_cast<const __m128i*>(in));

    //AddRoundKey(state, w[0, Nb-1])
    // just XOR
    state = _mm_xor_si128(state, w128[0]);

    //for round = 1 step 1 to Nr-1
        //SubBytes(state)
        //ShiftRows(state)
        //MixColumns(state)
        //AddRoundKey(state, w[round*Nb, (round+1)*Nb-1])
    //end for
    state = _mm_aesenc_si128(state, w128[1]);
    state = _mm_aesenc_si128(state, w128[2]);
    state = _mm_aesenc_si128(state, w128[3]);
    state = _mm_aesenc_si128(state, w128[4]);
    state = _mm_aesenc_si128(state, w128[5]);
    state = _mm_aesenc_si128(state, w128[6]);
    state = _mm_aesenc_si128(state, w128[7]);
    state = _mm_aesenc_si128(state, w128[8]);
    state = _mm_aesenc_si128(state, w128[9]);
    state = _mm_aesenc_si128(state, w128[10]);
    state = _mm_aesenc_si128(state, w128[11]);

    // The last round
    //SubBytes(state)
    //ShiftRows(state)
    //AddRoundKey(state, w[Nr*Nb, (Nr+1)*Nb-1])
    state = _mm_aesenclast_si128(state, w128[Nr]);

    //out = state
    _mm_storeu_si128(reinterpret_cast<__m128i*>(out), state);

    //end
}

//EqInvCipher(byte in[4*Nb], byte out[4*Nb], word dw[Nb*(Nr+1)])
//begin
void AES192_NI::EqInvCipher(const unsigned char in[4 * Nb], unsigned char out[4 * Nb])
{
    //  ASSERT(decrypting);

    //byte  state[4,Nb]
    //state = in
    __m128i state = _mm_loadu_si128(reinterpret_cast<const __m128i*>(in));

    //AddRoundKey(state, dw[Nr*Nb, (Nr+1)*Nb-1])
    state = _mm_xor_si128(state, dw128[Nr]);

    //for round = Nr-1 step -1 downto 1
        //InvSubBytes(state)
        //InvShiftRows(state)
        //InvMixColumns(state)
        //AddRoundKey(state, dw[round*Nb, (round+1)*Nb-1])
    //end for
    state = _mm_aesdec_si128(state, dw128[11]);
    state = _mm_aesdec_si128(state, dw128[10]);
    state = _mm_aesdec_si128(state, dw128[9]);
    state = _mm_aesdec_si128(state, dw128[8]);
    state = _mm_aesdec_si128(state, dw128[7]);
    state = _mm_aesdec_si128(state, dw128[6]);
    state = _mm_aesdec_si128(state, dw128[5]);
    state = _mm_aesdec_si128(state, dw128[4]);
    state = _mm_aesdec_si128(state, dw128[3]);
    state = _mm_aesdec_si128(state, dw128[2]);
    state = _mm_aesdec_si128(state, dw128[1]);

    //InvSubBytes(state)
    //InvShiftRows(state)
    //AddRoundKey(state, dw[0, Nb-1])
    state = _mm_aesdeclast_si128(state, dw128[0]);

    //out = state
    _mm_storeu_si128(reinterpret_cast<__m128i*>(out), state);
    //end
}

//static const BYTE Rcon[] = {
//  0x8d, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36, 0x6c, 0xd8, 0xab, 0x4d, 0x9a,
//};

//KeyExpansion(byte key[4*Nk], word w[Nb*(Nr+1)], Nk)
//begin
void AES192_NI::KeyExpansion(const unsigned char key[4 * Nk])
{
    //word  temp
    //i = 0

    //while (i < Nk)
        //w[i] = word(key[4*i], key[4*i+1], key[4*i+2], key[4*i+3])
        //i = i+1
    //end while
    __m128i work1 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(key));         // w3 : w2 : w1 : w0
    __m128i work2 = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(key + 16));    // - : - : w5 : w4

    w128[0] = work1;

    //i = Nk
    //while (i < Nb * (Nr+1)]
        //temp = w[i-1]
        //if (i mod Nk = 0)
            //temp = SubWord(RotWord(temp)) xor Rcon[i/Nk]
        //else if (Nk > 6 and i mod Nk = 4)
            //temp = SubWord(temp)
        //end if
        //w[i] = w[i-Nk] xor temp
        //i = i + 1
    //end while

    __m128i t, t2;
    __m128i work3;

    // f(w) = SubWord(RotWord(w)) xor Rcon

    // work1 = w3 : w2 : w1 : w0
    // work2 = - : - : w5 : w4
    // w6 = w0^f(w5)
    // w7 = w1^w6 = w1^w0^f(w5) 
    t = _mm_aeskeygenassist_si128(work2, 0x01); /* t = - : - : f(w5) : -     */
    t = _mm_shuffle_epi32(t, 0x55);     /* t     = f(w5) : f(w5) : f(w5) : f(w5) */
    t2 = _mm_slli_si128(work1, 4);      /* t2    = -     : -     : w0    : 0     */
    t2 = _mm_xor_si128(work1, t2);      /* t2    = -     : -     : w1^w0 : w0    */
    t = _mm_xor_si128(t2, t);           /* t     = -     : -     : w7    : w6    */
    work2 = _mm_unpacklo_epi64(work2, t);   /* work2 = w7    : w6    : w5    : w4    */
    w128[1] = work2;

    // work1 = w3 : w2 : w1 : w0
    // work2 = w7 : w6 : w5 : w4
    // w8 = w2^w7 
    // w9 = w3^w8 = w3^w2^w7
    // w10 = w4^w9 = w4^w3^w2^w7
    // w11 = w5^w10 = w5^w4^w3^w2^w7
#define EXPAND1(n)                                                                           \
    t     = _mm_alignr_epi8(work2, work1, 8);/* t    = w5          : w4       : w3    : w2 */\
    work3 = _mm_shuffle_epi32(work2, 0xFF); /* work3 = w7          : w7       : w7    : w7 */\
    t2    = _mm_slli_si128(t, 4);           /* t2    = w4          : w3       : w2    : 0  */\
    t     = _mm_xor_si128(t, t2);           /* t     = w5^w4       : w4^w3    : w3^w2 : w2 */\
    t2    = _mm_slli_si128(t, 8);           /* t2    = w3^w2       : w2       : 0     : 0  */\
    t2    = _mm_xor_si128(t, t2);           /* t2    = w5^w4^w3^w2 : w4^w3^w2 : w3^w2 : w2 */\
    work3 = _mm_xor_si128(t2, work3);       /* work3 = w11         : w10      : w9    : w8 */\
    w128[n] = work3;

    EXPAND1(2);

    // work2 = w7 : w6 : w5 : w4
    // work3 = w11: w10 : w9 : w8
    // w12 = w6^f(w11)
    // w13 = w7^w12 = w7^w6^f(w11)
    // w14 = w8^w13 = w8^w7^w6^f(w11)
    // w15 = w9^w14 = w9^w8^w7^w6^f(w11)
#define EXPAND2(n, RCON)                                                                          \
    t     = _mm_alignr_epi8(work3, work2, 8);/*t     = w9          : w8       : w7     : w6     */\
    work1 = _mm_aeskeygenassist_si128(work3, RCON);/* work1 = f(w11) : - : - : -                */\
    work1 = _mm_shuffle_epi32(work1, 0xFF); /* work1 = f(w11)      : f(w11)   : f(w11) : f(w11) */\
    t2    = _mm_slli_si128(t, 4);           /* t2    = w8          : w7       : w6     : 0      */\
    t     = _mm_xor_si128(t, t2);           /* t     = w9^w8       : w8^w7    : w7^w6  : w6     */\
    t2    = _mm_slli_si128(t, 8);           /* t2    = w7^w6       : w6       : 0      : 0      */\
    t     = _mm_xor_si128(t, t2);           /* t     = w9^w8^w7^w6 : w8^w7^w6 : w7^w6  : w6     */\
    work1 = _mm_xor_si128(t, work1);        /* work1 = w15         : w14      : w13    : w12    */\
    w128[n] = work1;

    EXPAND2(3, 0x02);

    const __m128i idx1 = _mm_set_epi8(-1, -1, -1, -1, -1, -1, -1, -1, 15, 14, 13, 12, 15, 14, 13, 12);
    const __m128i idx2 = _mm_set_epi8(7, 6, 5, 4, 7, 6, 5, 4, -1, -1, -1, -1, -1, -1, -1, -1);

    // work3 = w11 : w10 : w9 : w8
    // work1 = w15 : w14 : w13 : w12
    // w16 = w10^w15 
    // w17 = w11^w16 = w11^w10^w15
    // w18 = w12^f(w17)
    // w19 = w13^w18 = w13^w12^f(w17)
#define EXPAND3(n, RCON)                                                                      \
    work2 = _mm_alignr_epi8(work1, work3, 8);/*work2 = w13     : w12    : w11     : w10     */\
    t     = _mm_slli_epi64(work2, 32);      /* t     = w12     : 0      : w10     : 0       */\
    work2 = _mm_xor_si128(work2, t);        /* work2 = w13^w12 : w12    : w11^w10 : w10     */\
    t     = _mm_shuffle_epi8(work1, idx1);  /* t     = 0       : 0      : w15     : w15     */\
    work2 = _mm_xor_si128(work2, t);        /* work2 = w13^w12 : w12    : w17     : w16     */\
    t     = _mm_aeskeygenassist_si128(work2, RCON);/*t=-       : -      : f(w17)  : -       */\
    t     = _mm_shuffle_epi8(t, idx2);      /* t     = f(w17)  : f(w17) : 0       : 0       */\
    work2 = _mm_xor_si128(work2, t);        /* work2 = w19     : w18    : w17     : w16     */\
    w128[n] = work2;

    EXPAND3(4, 0x04);

    // Go on...
    EXPAND1(5);         // w20-w23
    EXPAND2(6, 0x08);   // w24-w27
    EXPAND3(7, 0x10);   // w28-w31

    EXPAND1(8);         // w32-w35
    EXPAND2(9, 0x20);   // w36-w39
    EXPAND3(10, 0x40);  // w40-w43

    EXPAND1(11);        // w44-w47
    EXPAND2(12, 0x80);  // w48-w51


    // Additional process for EqInvCipher
    if (decrypting) {
        //for i = 0 step 1 to (Nr+1)*Nb-1
            //dw[i] = w[i]
        //end for
        dw128[0] = w128[0];

        //for round = 1 step 1 to Nr-1
            //InvMixColumns(dw[round*Nb, (round+1)*Nb-1])
        //end for
        dw128[1] = _mm_aesimc_si128(w128[1]);
        dw128[2] = _mm_aesimc_si128(w128[2]);
        dw128[3] = _mm_aesimc_si128(w128[3]);
        dw128[4] = _mm_aesimc_si128(w128[4]);
        dw128[5] = _mm_aesimc_si128(w128[5]);
        dw128[6] = _mm_aesimc_si128(w128[6]);
        dw128[7] = _mm_aesimc_si128(w128[7]);
        dw128[8] = _mm_aesimc_si128(w128[8]);
        dw128[9] = _mm_aesimc_si128(w128[9]);
        dw128[10] = _mm_aesimc_si128(w128[10]);
        dw128[11] = _mm_aesimc_si128(w128[11]);

        dw128[Nr] = w128[Nr];
    }
    //end
}
#pragma once

#include <intrin.h>

class AES256_NI
{
public:
    // AES256
    static const int Nk = 8;
    static const int Nb = 4;
    static const int Nr = 14;

protected:
    __m128i w128[Nr + 1];
    __m128i dw128[Nr + 1];
    bool decrypting;

public:
    AES256_NI(const unsigned char key[4 * Nk], bool decrypting)
        : decrypting(decrypting)
    {
        KeyExpansion(key);
    }

    void Cipher(const unsigned char in[4 * Nb], unsigned char out[4 * Nb]);
    void EqInvCipher(const unsigned char in[4 * Nb], unsigned char out[4 * Nb]);

protected:
    void KeyExpansion(const unsigned char key[4 * Nk]);
};

#include "AES256_NI.h"

//Cipher(byte in[4*Nb], byte out[4*Nb], word w[Nb*(Nr+1)])
//begin
void AES256_NI::Cipher(const unsigned char in[4 * Nb], unsigned char out[4 * Nb])
{
    //  ASSERT(!decrypting);

    //state = in
    __m128i state = _mm_loadu_si128(reinterpret_cast<const __m128i *>(in));

    //AddRoundKey(state, w[0, Nb-1])
    // Just XOR
    state = _mm_xor_si128(state, w128[0]);

    //for round = 1 step 1 to Nr-1
        //SubBytes(state)
        //ShiftRows(state)
        //MixColumns(state)
        //AddRoundKey(state, w[round*Nb, (round+1)*Nb-1])
    //end for
    state = _mm_aesenc_si128(state, w128[1]);
    state = _mm_aesenc_si128(state, w128[2]);
    state = _mm_aesenc_si128(state, w128[3]);
    state = _mm_aesenc_si128(state, w128[4]);
    state = _mm_aesenc_si128(state, w128[5]);
    state = _mm_aesenc_si128(state, w128[6]);
    state = _mm_aesenc_si128(state, w128[7]);
    state = _mm_aesenc_si128(state, w128[8]);
    state = _mm_aesenc_si128(state, w128[9]);
    state = _mm_aesenc_si128(state, w128[10]);
    state = _mm_aesenc_si128(state, w128[11]);
    state = _mm_aesenc_si128(state, w128[12]);
    state = _mm_aesenc_si128(state, w128[13]);

    // The last round
    //SubBytes(state)
    //ShiftRows(state)
    //AddRoundKey(state, w[Nr*Nb, (Nr+1)*Nb-1])
    state = _mm_aesenclast_si128(state, w128[Nr]);

    //out = state
    _mm_storeu_si128(reinterpret_cast<__m128i *>(out), state);

    //end
}

//EqInvCipher(byte in[4*Nb], byte out[4*Nb], word dw[Nb*(Nr+1)])
//begin
void AES256_NI::EqInvCipher(const unsigned char in[4 * Nb], unsigned char out[4 * Nb])
{
    //  ASSERT(decrypting);

    //byte  state[4,Nb]
    //state = in
    __m128i state = _mm_loadu_si128(reinterpret_cast<const __m128i *>(in));

    //AddRoundKey(state, dw[Nr*Nb, (Nr+1)*Nb-1])
    state = _mm_xor_si128(state, dw128[Nr]);

    //for round = Nr-1 step -1 downto 1
        //InvSubBytes(state)
        //InvShiftRows(state)
        //InvMixColumns(state)
        //AddRoundKey(state, dw[round*Nb, (round+1)*Nb-1])
    //end for
    state = _mm_aesdec_si128(state, dw128[13]);
    state = _mm_aesdec_si128(state, dw128[12]);
    state = _mm_aesdec_si128(state, dw128[11]);
    state = _mm_aesdec_si128(state, dw128[10]);
    state = _mm_aesdec_si128(state, dw128[9]);
    state = _mm_aesdec_si128(state, dw128[8]);
    state = _mm_aesdec_si128(state, dw128[7]);
    state = _mm_aesdec_si128(state, dw128[6]);
    state = _mm_aesdec_si128(state, dw128[5]);
    state = _mm_aesdec_si128(state, dw128[4]);
    state = _mm_aesdec_si128(state, dw128[3]);
    state = _mm_aesdec_si128(state, dw128[2]);
    state = _mm_aesdec_si128(state, dw128[1]);

    //InvSubBytes(state)
    //InvShiftRows(state)
    //AddRoundKey(state, dw[0, Nb-1])
    state = _mm_aesdeclast_si128(state, dw128[0]);

    //out = state
    _mm_storeu_si128(reinterpret_cast<__m128i *>(out), state);
    //end
}

//static const BYTE Rcon[] = {
//  0x8d, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36, 0x6c, 0xd8, 0xab, 0x4d, 0x9a,
//};

//KeyExpansion(byte key[4*Nk], word w[Nb*(Nr+1)], Nk)
//begin
void AES256_NI::KeyExpansion(const unsigned char key[4 * Nk])
{
    //word  temp
    //i = 0

    //while (i < Nk)
        //w[i] = word(key[4*i], key[4*i+1], key[4*i+2], key[4*i+3])
        //i = i+1
    //end while
    __m128i work1 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(key));         
    __m128i work2 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(key + 16));
    w128[0] = work1;    // work1 = w3 : w2 : w1 : w0
    w128[1] = work2;    // work2 = w7 : w6 : w5 : w4

    //i = Nk
    //while (i < Nb * (Nr+1)]
        //temp = w[i-1]
        //if (i mod Nk = 0)
            //temp = SubWord(RotWord(temp)) xor Rcon[i/Nk]
        //else if (Nk > 6 and i mod Nk = 4)
            //temp = SubWord(temp)
        //end if
        //w[i] = w[i-Nk] xor temp
        //i = i + 1
    //end while

    __m128i t;

    // f(w) = SubWord(RotWord(w)) xor Rcon
    // g(w) = SubWord(w)

    // work1 = w3 : w2 : w1 : w0
    // work2 = w7 : w6 : w5 : w4
    // w8 = w0^f(w7)
    // w9 = w1^w8 = w1^w0^f(w7)
    // w10 = w2^w9 = w2^w1^w0^f(w7)
    // w11 = w3^w11 = w3^w2^w1^w0^f(w7)
#define EXPAND1(n, RCON)                                                                                        \
    t = _mm_slli_si128(work1, 4);   /* t     = w2                : w1             : w0          : 0        */   \
    work1 = _mm_xor_si128(work1, t);/* work1 = w3^w2             : w2^w1          : w1^w0       : w0       */   \
    t = _mm_slli_si128(work1, 8);   /* t     = w1^w0             : w0             : 0           : 0        */   \
    work1 = _mm_xor_si128(work1, t);/* work1 = w3^w2^w1^w0       : w2^w1^w0       : w1^w0       : w0       */   \
    t = _mm_aeskeygenassist_si128(work2, RCON); /* t = f(w7) : - : - : - */                                     \
    t = _mm_shuffle_epi32(t, 0xFF); /* t     = f(w7)             : f(w7)          : f(w7)       : f(w7)    */   \
    work1 = _mm_xor_si128(work1, t);/* work1 = w3^w2^w1^w0^f(w7) : w2^w1^w0^f(w7) : w1^w0^f(w7) : w0^f(w7) */   \
    w128[n] = work1;                /* work1 = w11 : w10 : w9 : w8 */

    EXPAND1(2, 0x01);

    // work2 = w7 : w6 : w5 : w4
    // work1 = w11 : w10 : w9 : w8
    // w12 = w4^g(w11)
    // w13 = w5^w12 = w5^w4^g(w11)
    // w14 = w6^w13 = w6^w5^w4^g(w11)
    // w15 = w7^w14 = w7^w6^w5^w4^g(w11)
#define EXPAND2(n)                                                                                              \
    t = _mm_slli_si128(work2, 4);   /* t     = w6                : w5             : w4          : 0        */   \
    work2 = _mm_xor_si128(work2, t);/* work2 = w7^w6             : w6^w5          : w5^w4       : w4       */   \
    t = _mm_slli_si128(work2, 8);   /* t     = w5^w4             : w4             : 0           : 0        */   \
    work2 = _mm_xor_si128(work2, t);/* work2 = w7^w6^w5^w4       : w6^w5^w4       : w5^w4       : w4       */   \
    t = _mm_aeskeygenassist_si128(work1, 0);    /* t = - : g(w11) : - : - */                                    \
    t = _mm_shuffle_epi32(t, 0xAA); /* t     = g(w11)            : g(w11)         : g(w11)      : g(w11)   */   \
    work2 = _mm_xor_si128(work2, t);/* work2 = w7^w6^w5^w4^g(w11): w6^w5^w4^g(w11): w5^w4^g(w11): w4^g(w11)*/   \
    w128[n] = work2;                /* work2 = w15 : w14 : w13 : w12 */

    EXPAND2(3);

    // Go on...
    EXPAND1(4, 0x02);
    EXPAND2(5);
    EXPAND1(6, 0x04);
    EXPAND2(7);
    EXPAND1(8, 0x08);
    EXPAND2(9);
    EXPAND1(10, 0x10);
    EXPAND2(11);
    EXPAND1(12, 0x20);
    EXPAND2(13);
    EXPAND1(14, 0x40);

    // Additional process for EqInvCipher
    if (decrypting) {
        //for i = 0 step 1 to (Nr+1)*Nb-1
            //dw[i] = w[i]
        //end for
        dw128[0] = w128[0];

        //for round = 1 step 1 to Nr-1
            //InvMixColumns(dw[round*Nb, (round+1)*Nb-1])
        //end for
        dw128[1] = _mm_aesimc_si128(w128[1]);
        dw128[2] = _mm_aesimc_si128(w128[2]);
        dw128[3] = _mm_aesimc_si128(w128[3]);
        dw128[4] = _mm_aesimc_si128(w128[4]);
        dw128[5] = _mm_aesimc_si128(w128[5]);
        dw128[6] = _mm_aesimc_si128(w128[6]);
        dw128[7] = _mm_aesimc_si128(w128[7]);
        dw128[8] = _mm_aesimc_si128(w128[8]);
        dw128[9] = _mm_aesimc_si128(w128[9]);
        dw128[10] = _mm_aesimc_si128(w128[10]);
        dw128[11] = _mm_aesimc_si128(w128[11]);
        dw128[12] = _mm_aesimc_si128(w128[12]);
        dw128[13] = _mm_aesimc_si128(w128[13]);

        dw128[Nr] = w128[Nr];
    }
//end
}

x86/x64 SIMD命令一覧表  フィードバック

ホームページ http://www.officedaytime.com/