Saturday, August 22, 2009

Корень из целого числа в compile-time часть 2

Не так давно в этом блоге мелькала статья на тему вычисления квадратного корня из целого числа в compile-time. Сейчас будет небольшое продолжение.

Некоторое время назад постоянный читатель этого блога прислал мне ссылку на задачку. Сайт Insidecpp иногда очень интересно почитать, и даже пафос и надменность автора как-то можно перетерпеть, когда видишь сколько он там всего понаписал. Прореагировал я примерно так: "чтобы моё решение подошло к этой задаче, нужно исправить всего 1 строчку". А потом задумался.

Автор предлагает реализовать вычисление квадратного корня в compile-time, но только с формулировкой "найти целое, ближайшее к правильному ответу", а не "наибольшее целое, квадрат которого не превышает":
Напишите вычисление квадратного корня в compile-time-е. Чтобы усложнить задачу я предлагаю следующее. Если квадратный корень от заданного числа не может быть вычислен в целых, то вашим результатом должно быть ближайшее целое к реальному результату. Обратите внимание, что речь идет именно о ближайшем целом. То есть не ближайшее меньшее результата, а просто ближайшее. Оно может быть как больше, так и меньше реального результата.


Т.е., если в моей задаче корень из 3 - это решительно 1, то в задаче автора - это должно быть 2. Решение автора меня очень порадовало своей краткостью, но к сожалению это решение при ближайшем рассмотрении оказалось неверным. Сейчас я расскажу как поправить код из предыдущей статьи, чтобы получить решение для задачки с insidecpp.ru.

Пусть дано:
y * y = x
0 < a < y < b
a и b - это целые числа, ближайшие к y. Нам нужно из a и b выбрать число, которое отличается от y минимально.
Т.е. d1=y-a, d2=b-y
Если d1 < d2, тогда нужно выбрать a, если d1 > d2, тогда - b.

т.е. если буквой z обозначить знак, который ставится между d1 и d2 при сравнении, то получим:
d1 z d2
y-a z b-y
2*y z a+b
4*y*y z (a+b)*(a+b)
как известно, y*y это x, т.е.:
4*x z (a+b)*(a+b)
т.к. в нашем случае вместо a и b используются соответственно lo и lo+1, результат:
4*x z 4*lo*lo + 4*lo + 1
или, чтобы избежать переполнения,
4*(x-lo*lo) z 4*lo + 1
если 4*(x-lo*lo) > 4*lo + 1, выбераем - lo+1. Иначе - lo.

Единственное место в коде, которое необходимо поправить - это:

template<unsigned x, unsigned lo, unsigned y>
struct sqrt<x, lo, 1, y, less>
{
enum { value = lo * lo + 2 * lo <= x - 1 ? (lo + 1) : lo };
};

Меняем логику на:

template<unsigned x, unsigned lo, unsigned y>
struct sqrt<x, lo, 1, y, less>
{
enum { value = 4*(x-lo*lo) > 4*lo + 1 ? lo+1:lo };
};

Задача решена.

Monday, August 10, 2009

Корень из целого числа в компайл-тайм

Я люблю С++. Я ненавижу C++. В пользу обоих утверждений можно привести как минимум по десятку доводов, но независимо от того, какие доводы перевесят, C++ - это очень интересный язык. Одной из важнейших особенностей C++ является механизм шаблонов. С одной стороны в сообществе C++ принято говорить, что шаблоны - это мощнейшее средство, которое позволяет свести дублирование кода к нулю. Яркий пример - стандартная библиотека. Думаю, никто не станет спорить, что это вещь крайне полезная и вполне себе удобная. Но с другой стороны, всегда найдутся люди, которые найдут любому средству такое применение, что волосы дыбом становятся (см. хотя бы boost spirit). Хорошее это средство или плохое, в любом случае оно позволяет сделать процесс написания кода крайне увлекательным занятием.

Эта статья о том, как вычислить квадратный корень из целого числа на этапе компиляции.

Если читатель хоть немного знаком с C++, могу предположить, что он уже пробовал реализовать вычисление факториала на этапе компиляции. Например, так:

template<unsigned x>
struct factorial
{
enum { value = x * factorial<x-1>::value };
};

template<>
struct factorial<1>
{
enum { value = 1 };
};

Чтобы вычислить факториал, достаточно написать:
factorial<5>::value

Это выражение обладает всеми свойствами, которыми обладает константа. Значение этого выражения можно использовать для объявления массива, для специализации/инстанциирования шаблонов и для всего чего угодно остального, т.е. значение известно на этапе компиляции.

Но вернёмся к извлечению корня. Вообще квадратным корнем из числа X называется такое неотрицательное число Y, квадрат которого равен X. Для целых чисел эта формулировка не подходит. Мы будем отталкиваться от формулировки: квадратными корнем из целого числа X называется такое наибольшее целое число Y, квадрат которого не превышает X.

Для начала я опишу алгоритм, который мы будем реализовывать. Пусть нам дано цело число x, а корень, который мы ищем - y. Пусть так же даны 2 числа lo и range, такие, что изначально lo = 0, range = x. На каждой итерации работы y будет серединой диапазона и на каждой итерации мы будем этот диапазон корректировать. Вот простой пример:
x=36
1. lo = 0, range = 36, y = lo + range / 2 = 18. 18*18 > 36, значит берём меньшую половину диапазона.
2. lo = 0, range = 18, y = lo + range / 2 = 9. 9*9 > 36, берём меньшую.
3. lo = 0, range = 9, y = lo + range / 2 = 4. 4*4 < 36, берём большую.
4. lo = 4, range = 4, y = lo + range / 2 = 6. 6*6 = 36 - это ответ.

Возможны и другие варианты развития событий, но о них - позже.

Начнём.

template
<
unsigned x, // число, корень которого ищем
unsigned lo = 0, // начало диапазона
unsigned range = x, // ширина диапазона
unsigned y = lo + range / 2 // значение корня для текущей итерации
>
struct sqrt;

Использование шаблонных параметров со значениями по умолчанию позволяет нам делать нужные вычисления. При инстанциировании этого шаблона в виде sqrt<36>, x будет равен 36, lo - 0, range - 36, y = 18. Что делать дальше? Нам нужно каким-то образом реализовать логику выбора - нижняя половина диапазона, верхняя, ответ уже найден. Для этого опишем вспомогательный класс:

// состояние алгоритма sqrt - уменьшать, увеличивать, вычислено
enum state { less, greater, equal };

// получение состояния алгоритма
template
<
unsigned x, // нужное число
unsigned y // кандидат на корень
>
struct sqrt_state
{
enum { value = x == y*y ? equal : (x < y*y ? greater : less) };
};

sqrt_state<x, y> проверяет отношения равенства/неравенства между x и y*y. В случае если y*y < x, результат sqrt_state::value == less, если y*y > x, то greater, иначе equal. Изменим шаблон sqrt следующим образом:

// sqrt
template
<
unsigned x, // число для извлечения корня
unsigned lo = 0, // нижняя граница диапазона
unsigned range = x, // ширина диапазона
unsigned y = lo + range / 2, // середина диапазона
unsigned state = sqrt_state<x, y>::value // состояние алгоритма
>
struct sqrt;

Теперь при инстанциировании шаблона sqrt, ему известно что нужно делать. В случае less он должен обращаться сам к себе, передавая в качестве аргумента верхний поддиапазон, в случае greater - нижний, а в случае equal рекурсия должна заканчиваться, т.к. ответ найден:

// квадрат проверяемого числа меньше, чем нужное число
template<unsigned x, unsigned lo, unsigned range, unsigned y>
struct sqrt<x, lo, range, y, less>
{
enum { value = sqrt<x, y, range / 2>::value };
};

// квадрат проверяемого числа больше, чем нужное число
template<unsigned x, unsigned lo, unsigned range, unsigned y>
struct sqrt<x, lo, range, y, greater>
{
enum { value = sqrt<x, lo, range / 2>::value };
};

// квадрат проверяемого числа равен нужному числу
template<unsigned x, unsigned lo, unsigned range, unsigned y>
struct sqrt<x, lo, range, y, equal>
{
enum { value = y };
};

Теперь, если проверить написанное с помощью:
sqrt<36>::value

Получим правильный ответ - 6. А если попробовать 37? Ошибка компиляции. Почему? Всё очень просто. Дело дошло до того, когда range стал равным 0. При этом специализация
struct sqrt<x, lo, range, y, less>

попыталась вызвать саму себя с теми же параметрами, с которыми она была вызвана. Как же быть? Тут мы столкнулись с ситуацией, когда целочисленный корень из числа не существует. Добавим ещё одну специализацию sqrt:

// квадрат проверяемого числа меньше, чем нужное число, диапазон равен 1
// наибольшее целое, квадрат которого не превышает заданное
template<unsigned x, unsigned lo, unsigned y>
struct sqrt<x, lo, 1, y, less>
{
enum { value = (lo+1)*(lo+1) <= x ? (lo + 1) : lo };
};

Эта специализация начинает работать когда range=1. Когда это происходит? Когда корень числа находится между двумя целыми числами. Для 37 эти числа - 6 и 7. При вычислении value в этой специализации мы проверяем какое из двух чисел "лучше" подходит - 6 или 7. В данном случае выбераем 6. Вроде бы всё. Проверим как работает наш sqrt для других чисел. Для этого напишем вспомогательный класс static_assert (аналогичная штука уже есть в c++0x, но сделаем вид, что у нас нет c++0x):

static_assert<sqrt<0ul>::value == 0ul>();
static_assert<sqrt<1ul>::value == 1ul>();
static_assert<sqrt<2ul>::value == 1ul>();
static_assert<sqrt<3ul>::value == 1ul>();
static_assert<sqrt<4ul>::value == 2ul>();
static_assert<sqrt<5ul>::value == 2ul>();
static_assert<sqrt<6ul>::value == 2ul>();
static_assert<sqrt<7ul>::value == 2ul>();
static_assert<sqrt<8ul>::value == 2ul>();
static_assert<sqrt<9ul>::value == 3ul>();
static_assert<sqrt<13ul>::value == 3ul>();
static_assert<sqrt<100ul>::value == 10ul>();
static_assert<sqrt<256ul>::value == 16ul>();
static_assert<sqrt<257ul>::value == 16ul>();
static_assert<sqrt<2048ul>::value == 45ul>();
static_assert<sqrt<9801ul>::value == 99ul>();
static_assert<sqrt<65533ul>::value == 255ul>();
static_assert<sqrt<65534ul>::value == 255ul>();
static_assert<sqrt<65536ul>::value == 256ul>();
static_assert<sqrt<999999ul>::value == 999ul>();
static_assert<sqrt<1000000ul>::value == 1000ul>();
static_assert<sqrt<12345678ul>::value == 3513ul>();
static_assert<sqrt<123456789ul>::value == 11111ul>();
static_assert<sqrt<4294836223ul>::value == 65534ul>();
static_assert<sqrt<4294836224ul>::value == 65534ul>();
static_assert<sqrt<0xfffeul*0xfffeul>::value == 0xfffeul>();
static_assert<sqrt<0xfffful*0xfffful>::value == 0xfffful>();
static_assert<sqrt<0x10001ul*0xfffful>::value == 0xfffful>();

Пробуем откомпилировать и... компилятор сообщает нам о множестве ошибок и предупреждений. В чём же дело? Начнём с того, что в некоторых местах происходит целочисленное переполнение.

template
<
unsigned x, // нужное число
unsigned y // кандидат на корень
>
struct sqrt_state
{
enum { value = x == y*y ? equal : (x < y*y ? greater : less) };
};

Видим y*y. Проблема связана с тем, что в случае если unsigned - это 32-разрядное целое без знака, то область определения y для выражения y*y - это 1<<(32/2)-1, т.е. 0xffff. Но это частный случай. В общем случае можно реализовать дополнительный класс, который позволит нам проверять y на невыходимость за границы дозволенного его квадрата:

// проверка не превышение наибольшего числа, которое
// можно возвести в квадрат без переполнения
template<unsigned x>
struct is_greater_than_max_squarable
{
enum
{
max_squarable = (1 << (std::numeric_limits<unsigned>::digits / 2)) - 1,
value = x > max_squarable
};
};

Для использования std::numeric_limits нужно подключить стандартный заголовок limits. Название класса говорит само за себя. Исправим логическую часть алгоритма следующим образом:

// получение состояния алгоритма
template
<
unsigned x, // нужное число
unsigned y, // кандидат на корень
bool y_is_too_large = is_greater_than_max_squarable<y>::value
>
struct sqrt_state;

// y превышает наибольшее, которое можно возводить в квадрат
template<unsigned x, unsigned y>
struct sqrt_state<x, y, true>
{
enum { value = greater };
};

// y не превышает наибольшее, которое можно возводить в квадрат
template<unsigned x, unsigned y>
struct sqrt_state<x, y, false>
{
enum { value = x == y*y ? equal : (x < y*y ? greater : less) };
};

Идея в том, что квадрат числа, большего чем 0xffff даст число, большее чем умещается в переменной типа unsigned. Наша модификация заставляет sqrt_state расценивать число как слишком большое в случае если оно больше, чем 0xffff.

Снова компилируем и... снова ошибка. На этот раз одна. Речь идёт о строке:
static_assert<sqrt<0x10001ul*0xfffful>::value == 0xfffful>();

Что же здесь могло произойти? Реальная проблема находится здесь:
enum { value = (lo+1)*(lo+1) <= x ? (lo + 1) : lo };

lo у нас вполне может оказаться равным 0xffff, соответственно (0xffff+1)*(0xffff+1) это 0x10000*0x10000 - а это много. Проблема решается следующим волшебным образом:
enum { value = lo * lo + 2 * lo <= x - 1 ? (lo + 1) : lo };

В таком варианте переполнения не происходит, хотя исправленное условное выражение означает то же самое, что и исходное. Кстати, вопросы по поводу разницы между a+b+c и a+c+b иногда задают на собеседованиях :)

В качестве заключения: sqrt.