2005年10月24日

普通のやつらの下を行け: assert_caller()

以前に、低レベルプログラミングを愛好する知人が「普通のやつらの下を行け」を口癖にしていました。当時は何を言っているのかと聞き流していましたが、自分も最近になってようやく低レベルプログラミングのおもしろさがわかってきました。今回は「普通のやつらの下を行け」企画の第一弾として assert_caller() なるものを作ってみたいと思います。

 

assert_caller() とは

assert_caller() とは、特定の関数からの関数呼び出しだけを通すためのアサーションです。たとえば、次のように foo() の先頭で assert_caller(main) と書いた場合、 foo() は main() からしか呼び出せなくなります。他の関数から foo() を呼び出した場合はエラーメッセージとともに異常終了します。

void foo() {
  assert_caller(main);
  ...
}

これが便利なのかよくわかりませんが、こういうものがあれば、C でも protected や friend に似たアクセス制限を実現できるかもしれません。assert_caller() は鵜飼さんに命名されました。

assert_caller() の実装

今いる関数がどの関数から呼ばれてきたのかがわかれば assert_caller() は簡単に実現できます。しかし、C言語には、どの関数から呼ばれたかを調べる方法はありません。ここでは GCC の独自拡張と ELF バイナリ内に含まれる情報を利用してこの問題に対処することにします。

GCC では実行中の関数のリターンアドレスを __builtin_return_address(0) で取得できます。次のプログラムをコンパイルして実行すると、手元の環境では 80483b4 と表示されました。

#include <stdio.h>
int foo() {
    printf("%x\n", __builtin_return_address(0));
}
int main() {
    foo();
    return 0;
}

次に、この実行形式ファイルのシンボルの情報を nm で見てみます。nm の出力によると、 main() の開始アドレスは 804839f、長さは 1cバイトとなっています。つまり main() の実装はアドレス 804839f からはじまり、 80483ba で終わることがわかります。シンボルの情報はバイナリを strip すると削られてしまいますが、普通にコンパイルした時点では残っています。

% nm -S a.out |grep -w main
0804839f 0000001c T main

foo() のリターンアドレスアドレスはさきほど調べたとおり、 80483b4 ですから、main() の範囲 (804839f ~ 80483ba) に収まっています。これで foo() が main() から確かに呼ばれていることがわかりました。

このように、GCC の拡張と ELF バイナリ内の情報を用いれば、どの関数がどこから呼ばれているかを調べることができます。assert_caller() はこれらを使って実装することにします。以下がそのコードです。

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/types.h>
#include <unistd.h>
#include <assert.h>

struct symlist {
    char *name;
    size_t start;
    size_t length;
    struct symlist *next;
} *symbols;

void init_symbols()
{
    char line[1024];
    char command[1024];
    snprintf(command, sizeof(command), "nm -S /proc/%d/exe", getpid());
    FILE *fp = popen(command, "r");
    if (fp == NULL) {
        perror(command);
    }
    while (fgets(line, sizeof(line), fp) != NULL) {
        size_t start, length;
        char name[1024];
        if (sscanf(line, "%x %x T %s", &start, &length, name) == 3) {
            struct symlist *s = malloc(sizeof(struct symlist));
            if (s == NULL) {
                perror("malloc");
            }
            s->name = strdup(name);
            s->start = start;
            s->length = length;
            s->next = symbols;
            symbols = s;
        }
    }
}

#define assert_caller(func_name) do {                              \
        const char *caller = "(unknown)";                          \
        size_t from = (size_t)__builtin_return_address(0);         \
        struct symlist *s;                                         \
        for (s = symbols; s != NULL; s = s->next) {                \
            if (s->start <= from && from < s->start + s->length) { \
                caller = s->name;                                  \
            }                                                      \
        }                                                          \
        if (strcmp(caller, #func_name) != 0) {                     \
            fprintf(stderr,                                        \
                    "%s:%d: assert_caller(%s) failed. "            \
                    "called from %s\n",                            \
                    __FILE__, __LINE__, #func_name, caller);       \
            abort();                                               \
        }                                                          \
    } while(0)

void foo() {
    assert_caller(main);
    printf("hello\n");
}

void bar() {
    foo();
}

int main()
{
    init_symbols();
    foo();
    bar();
    return 0;
}

このプログラムをコンパイルして実行すると、main() から呼ばれた foo() は問題なく実行され、 bar() から呼ばれた場合は assert_caller() のところで落ちます。

% ./a.out
hello
assert_caller.c:60: assert_caller(main) failed. called from bar
zsh: 2898 abort (core dumped)  ./a.out

なお、上のコードでは nm の出力を使ってシンボルの情報を取得しています。 外部コマンドを呼び出さないでシンボルの情報を取得するには libbfd というライブラリを使うと便利です。libbfd を使ったコードの例としては、鵜飼さんの livepatch が参考になります。また、livepatch では /proc/PID/maps を見て共有ライブラリ内のシンボル情報の取得およびアドレスの再配置の計算を行っていますが、ここでは実行形式ファイルに含まれるシンボルだけを対象としました。

まとめ

「普通のやつらの下を行け」という割に、肝心な部分は nm を呼び出すという軟弱なコードになってしまいましたが、ひとまず assert_caller() を作るという目的は達成できました。 GCC の __builtin_return_address() と ELF のシンボルの情報を使えば他にもいろいろおもしろいことができると思います。

ちなみに、「普通のやつらの下を行け」は普通のやつらの上を行けをもじったものです。

追記

鵜飼さんによると、動的共有オブジェクト内のシンボルは dladdr(3) で取得できるとのこと。これは便利。.dynsym セクションは strip しても消えないところもポイントです。さっそく先人に下を行かれてしまいました。