x86/x64 SIMD命令によるAES暗号処理

x86/x64のAESNI命令群の使い方を解説します。

基本

AES暗号については、米国の政府機関が出しているFIPS 197という文書がオリジナルの仕様になりますのでそちらを見ながら読んでください。

ブロックサイズ、キーサイズ

AES暗号ではキーサイズは128ビット、192ビット、256ビットの3種類があり、AES-128、AES-192、AES-256と呼ばれます。

各キーサイズごとに固有の定数Nk(キー長)、Nb(ブロック長)、Nr(ラウンド回数)があり下のように決まっています。NkとNbの単位はwordで、AESでは「word」は32ビットです。


(FIPS 197より抜粋)

煩わしいので以下Nbは4と書きます。

AES暗号ではどのキー長でもブロック長は4ワード(16バイト)固定です。暗号化では16バイトの平文を入力して16バイトの暗号文を得ます。復号では16バイトの暗号文を入力して16バイトの平文を得ます。

state配列

state配列は4*4バイトの二次元配列で、ここに平文を入れたあと、roundと呼ばれる処理を繰り返して加工することで暗号文を作ります。roundはNr回繰り返します。


(FIPS 197より抜粋)

AESNIでは、XMMWORDひとつにこの配列を丸ごと入れて処理します。

w配列

w配列は、Nkワードのキーをもとに、あらかじめKeyExpansionという処理をして作っておく4 * (Nr + 1)ワードの一次元配列です。round 1回ごとに4ワードずつ使っていきます。roundの周回に入る前に下処理でひとつ使うのでNr+1回分が必要です。

AESNIでは、XMMWORDひとつにround 1回分(4ワード)の要素を丸ごと入れて処理します(AESのワードは32ビットです)。

 

暗号化

以下はFIPS 197に書いてある暗号化のアルゴリズムです。色は解説の都合上私がつけたものです。AESNI命令ではの部分を行います。のAddRoundKey()はxorしているだけなのでpxor命令でできます。

AESENC - AES ENCrypt
AESENCLAST -AES ENCrypt LAST

AESENC xmm1, xmm2/m128
__m128i  _mm_aesenc_si128(__m128i state, __m128i w);
AESENCLAST xmm1, xmm2/m128
__m128i  _mm_aesenclast_si128(__m128i state, __m128i w);

VAESENC xmm1, xmm2, xmm3/m128
VAESENCLAST xmm1, xmm2, xmm3/m128

AESENCは、の4つの関数を呼び出している部分(1ラウンド分)を1命令で処理します。

は最終ラウンドですが、この回だけ呼び出す関数がひとつ少ないので、AESENCLAST命令で処理します。

①にラウンド前のstate配列、②にそのラウンドで使うw配列の要素を入れて実行すると、③にラウンド後のstate配列が返ります。

命令の中で具体的に何をやっているかはFIPS 197の各関数の説明をご覧ください。

サンプル

上のPseudo CodeをC++とSIMDのintrinsicsに置き換えてみたサンプルです。Visual Studio 2015で動作確認しました。

//Cipher(byte in[4*Nb], byte out[4*Nb], word w[Nb*(Nr+1)])
//begin
void AES_SIMD::Cipher(const BYTE in[4 * Nb], BYTE out[4 * Nb])
{
//  ASSERT(!decrypting);

    //state = in
    __m128i state = _mm_loadu_si128(reinterpret_cast<const __m128i *>(in));

    //AddRoundKey(state, w[0, Nb-1])
    // xor演算するだけ
    state = _mm_xor_si128(w.w128[0], state);

    //for round = 1 step 1 to Nr?1
    for (int round = 1; round <= Nr - 1; round++) {
        //SubBytes(state)
        //ShiftRows(state)
        //MixColumns(state)
        //AddRoundKey(state, w[round*Nb, (round+1)*Nb-1])
        state = _mm_aesenc_si128(state, w.w128[round]);

    //end for
    }

    // 最後のラウンド
    //SubBytes(state)
    //ShiftRows(state)
    //AddRoundKey(state, w[Nr*Nb, (Nr+1)*Nb-1])
    state = _mm_aesenclast_si128(state, w.w128[Nr]);

    //out = state
    _mm_storeu_si128(reinterpret_cast<__m128i *>(out), state);

//end
}
class AES_SIMD
{
public:
    // AES128
//  static const int Nk = 4;
//  static const int Nb = 4;
//  static const int Nr = 10;

    // AES192
//  static const int Nk = 6;
//  static const int Nb = 4;
//  static const int Nr = 12;

    // AES256
    static const int Nk = 8;
    static const int Nb = 4;
    static const int Nr = 14;

public:
    typedef unsigned char BYTE;
    typedef unsigned long AESWORD;

protected:
    union W {
        __m128i w128[Nr + 1];
        AESWORD w32[Nb * (Nr + 1)];
    };
    W w;
    W dw;
    bool decrypting;

public:
    AES_SIMD(const BYTE key[4 * Nk], bool decrypting)
        : decrypting(decrypting)
    {
        KeyExpansion(key);
    }

    void Cipher(const BYTE in[4 * Nb], BYTE out[4 * Nb]);
    void EqInvCipher(const BYTE in[4 * Nb], BYTE out[4 * Nb]);

protected:
    void KeyExpansion(const BYTE key[4 * Nk]);
};

w配列(round用キー)を生成する

暗号化の前にKey Expansionで暗号キー(Nkワードのデータ)からw配列を作ります。

以下はFIPS 197のKeyExpansionのアルゴリズムです。

Rcon配列はFIPS 197で定義された方法で計算された定数の配列で、具体的な値は以下のようになります。AES暗号で実際に使われるのは[1]~[10]の範囲だけです。

static const BYTE Rcon[] = {
 0x8d, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36, 0x6c, 0xd8, 0xab, 0x4d, 0x9a,
};

AESKEYGENASSIST - AES KEY GENeration ASSIST

AESKEYGENASSIST xmm1, xmm2/m128, imm8
VAESKEYGENASSIST xmm1, xmm2/m128, imm8
__m128i  _mm_aeskeygenassist_si128(__m128i temp, const int Rcon);

AESKEYGENASSISTは、①にtempの値を入れて、②にRcon[i/Nk]の値を指定して実行すると、の2つの式の値を計算して③に返してくれます。

AESKEYGENASSIST命令は2つの値を返してきますが、周回ごとに使うのはどちらか一方だけなので、もう一方は捨てることになります。また、2セット同時に計算してくれますが、これは2つの別の暗号キーを同時に処理するのでなければ使いようがないと思います。

なお、の値を使うのはAES-256の場合だけです。

サンプル

この命令にはちょっとやっかいな点があって、②の値に定数しか書けません。そのためRconの配列を用意しておいて周回ごとに値を取ってきて入れるということができません。上のPseudo Codeの流れのままコーディングするなら以下のようになるかと思います。

//KeyExpansion(byte key[4*Nk], word w[Nb*(Nr+1)], Nk)
//begin
void AES_SIMD::KeyExpansion(const BYTE key[4 * Nk])
{
    //word  temp
    //i = 0

    //while (i < Nk)
    //w[i] = word(key[4*i], key[4*i+1], key[4*i+2], key[4*i+3])
    //i = i+1
    //end while
    memcpy(w.w32, key, 4 * Nk);

    //i = Nk
    int i = Nk;
    //while (i < Nb * (Nr+1)]
    for (; i < Nb * (Nr + 1); i++) {
        //temp = w[i-1]
        AESWORD temp = w.w32[i - 1];

        //if (i mod Nk = 0)
        if (i % Nk == 0) {
            //temp = SubWord(RotWord(temp)) xor Rcon[i/Nk]
            __m128i t;
            t.m128i_u32[1] = temp;
            switch (i / Nk) {
            case 1:
                t = _mm_aeskeygenassist_si128(t, 0x01);
                break;
            case 2:
                t = _mm_aeskeygenassist_si128(t, 0x02);
                break;
            case 3:
                t = _mm_aeskeygenassist_si128(t, 0x04);
                break;
            case 4:
                t = _mm_aeskeygenassist_si128(t, 0x08);
                break;
            case 5:
                t = _mm_aeskeygenassist_si128(t, 0x10);
                break;
            case 6:
                t = _mm_aeskeygenassist_si128(t, 0x20);
                break;
            case 7:
                t = _mm_aeskeygenassist_si128(t, 0x40);
                break;
            case 8:
                t = _mm_aeskeygenassist_si128(t, 0x80);
                break;
            case 9:
                t = _mm_aeskeygenassist_si128(t, 0x1b);
                break;
            case 10:
                t = _mm_aeskeygenassist_si128(t, 0x36);
                break;
            default:
//              ASSERT(0);
                break;
            }

            temp = t.m128i_u32[1];
        }
        //else if (Nk > 6 and i mod Nk = 4)
        else if (Nk > 6 && (i % Nk) == 4) {
            //temp = SubWord(temp)
            __m128i t;
            t.m128i_u32[1] = temp;
            t = _mm_aeskeygenassist_si128(t, 0);
            temp = t.m128i_u32[0];
        }
        //end if

        //w[i] = w[i-Nk] xor temp
        w.w32[i] = w.w32[i - Nk] ^ temp;

        //i = i + 1
    }
    //end while
    //end

}

たぶん、ループするのではなく、すべての周回相当の処理をインラインでずらっと並べて書くのがいいんじゃないかと思います。

AES-256のKeyExpansionをループなしで書くとこんな感じです。

//KeyExpansion(byte key[4*Nk], word w[Nb*(Nr+1)], Nk)
//begin
void AES256_SIMD::KeyExpansion(const BYTE key[4 * Nk])
{
    //word  temp
    //i = 0

    //while (i < Nk)
    //w[i] = word(key[4*i], key[4*i+1], key[4*i+2], key[4*i+3])
    //i = i+1
    //end while
    memcpy(w.w32, key, 4 * Nk);

    //i = Nk
    //while (i < Nb * (Nr+1)]
    //temp = w[i-1]
    //if (i mod Nk = 0)
    //temp = SubWord(RotWord(temp)) xor Rcon[i/Nk]
    //else if (Nk > 6 and i mod Nk = 4)
    //temp = SubWord(temp)
    //end if
    //w[i] = w[i-Nk] xor temp
    //i = i + 1
    //end while

#define KX_0(i, Rcon) w.w32[i] =                \
    _mm_extract_epi32(                          \
        _mm_aeskeygenassist_si128(              \
            _mm_slli_si128(                     \
                _mm_cvtsi32_si128(w.w32[i-1]),  \
                4                               \
            ),                                  \
            Rcon                                \
        ),                                      \
        1                                       \
    ) ^ w.w32[i - Nk];

#define KX_4(i) w.w32[i] =                      \
    _mm_cvtsi128_si32(                          \
        _mm_aeskeygenassist_si128(              \
            _mm_slli_si128(                     \
                _mm_cvtsi32_si128(w.w32[i-1]),  \
                4                               \
            ),                                  \
            0                                   \
        )                                       \
    ) ^ w.w32[i - Nk];      

#define KX(i) w.w32[i] = w.w32[i-1] ^ w.w32[i-Nk];

    KX_0(8, 0x01)   KX(9)   KX(10)  KX(11)      KX_4(12)    KX(13)  KX(14)  KX(15)
    KX_0(16, 0x02)  KX(17)  KX(18)  KX(19)      KX_4(20)    KX(21)  KX(22)  KX(23)
    KX_0(24, 0x04)  KX(25)  KX(26)  KX(27)      KX_4(28)    KX(29)  KX(30)  KX(31)
    KX_0(32, 0x08)  KX(33)  KX(34)  KX(35)      KX_4(36)    KX(37)  KX(38)  KX(39)
    KX_0(40, 0x10)  KX(41)  KX(42)  KX(43)      KX_4(44)    KX(45)  KX(46)  KX(47)
    KX_0(48, 0x20)  KX(49)  KX(50)  KX(51)      KX_4(52)    KX(53)  KX(54)  KX(55)
    KX_0(56, 0x40)  KX(57)  KX(58)  KX(59)
    //end
}

復号

FIPS 197では、復号のアルゴリズムとして、暗号化の手順を逆にしただけの「InvCipher」と、それと等価な結果を得られる「EqInvCipher」という2つが定義されていますが、AESNIがサポートするのは後者のみです。

EqInvCipherでは、w配列の生成処理に後処理をちょっと追加して作ったdw配列を使います。dw配列の作り方は後述します。

AESDEC - AES DECrypt
AESDECLAST -AES DECrypt LAST

AESDEC xmm1, xmm2/m128
__m128i  _mm_aesdec_si128(__m128i state, __m128i dw);
AESDECLAST xmm1, xmm2/m128
__m128i  _mm_aesdeclast_si128(__m128i state, __m128i dw);

VAESDEC xmm1, xmm2, xmm3/m128
VAESDECLAST xmm1, xmm2, xmm3/m128

AESDECは、の4つの関数を呼び出している部分(1ラウンド分)を1命令で処理します。

は最終ラウンドですが、この回だけ呼び出す関数がひとつ少ないので、AESDECLAST命令で処理します。

①にラウンド前のstate配列、②にそのラウンドで使うdw配列の要素を入れて実行すると、③にラウンド後のstate配列が返ります。

サンプル

//EqInvCipher(byte in[4*Nb], byte out[4*Nb], word dw[Nb*(Nr+1)])
//begin
void AES_SIMD::EqInvCipher(const BYTE in[4 * Nb], BYTE out[4 * Nb])
{
//  ASSERT(decrypting);

    //byte  state[4,Nb]
    //state = in
    __m128i state = _mm_loadu_si128(reinterpret_cast<const __m128i *>(in));

    //AddRoundKey(state, dw[Nr*Nb, (Nr+1)*Nb-1])
    state = _mm_xor_si128(state, dw.w128[Nr]);

    //for round = Nr-1 step -1 downto 1
    for (int round = Nr - 1; round >= 1; round--) {
        //InvSubBytes(state)
        //InvShiftRows(state)
        //InvMixColumns(state)
        //AddRoundKey(state, dw[round*Nb, (round+1)*Nb-1])
        state = _mm_aesdec_si128(state, dw.w128[round]);

    //end for
    }

    //InvSubBytes(state)
    //InvShiftRows(state)
    //AddRoundKey(state, dw[0, Nb-1])
    state = _mm_aesdeclast_si128(state, dw.w128[0]);

    //out = state
    _mm_storeu_si128(reinterpret_cast<__m128i *>(out), state);
    //end
}

dw配列を生成する

EqInvCipher用のdw配列は、前述のKeyExpansion処理で作ったw配列に、以下の後処理をすることで作ります。

AESIMC - AES InvMixColumns

AESIMC xmm1, xmm2/m128
VAESIMC xmm1, xmm2/m128
__m128i  _mm_aesimc_si128(__m128i w);

上のInvMixColumns()の部分を処理する命令です。w配列の要素を①に入れて実行すると、後処理後のdw配列の要素が得られます。

サンプル

前述のKeyExpansionのサンプルの最後に追加するコードです。

    // EqInvCipher用の追加処理
    if (decrypting) {
        //for i = 0 step 1 to (Nr+1)*Nb-1
        //dw[i] = w[i]
        //end for
        dw = w;

        //for round = 1 step 1 to Nr-1
        for (int round = 1; round <= Nr - 1; round++) {
            //InvMixColumns(dw[round*Nb, (round+1)*Nb-1])
            dw.w128[round] = _mm_aesimc_si128(dw.w128[round]);
        //end for
        }
    }

サンプルソースファイル

上のサンプル、及びループなし版のAES-128、AES-192、AES-256のサンプルのソースファイルを置いておきますのでご自由にご利用ください。

AES_SIMD_2.zip


x86/x64 SIMD命令一覧表  フィードバック

ホームページ http://www.officedaytime.com/