Я люблю С++. Я ненавижу C++. В пользу обоих утверждений можно привести как минимум по десятку доводов, но независимо от того, какие доводы перевесят, C++ - это очень интересный язык. Одной из важнейших особенностей C++ является механизм шаблонов. С одной стороны в сообществе C++ принято говорить, что шаблоны - это мощнейшее средство, которое позволяет свести дублирование кода к нулю. Яркий пример - стандартная библиотека. Думаю, никто не станет спорить, что это вещь крайне полезная и вполне себе удобная. Но с другой стороны, всегда найдутся люди, которые найдут любому средству такое применение, что волосы дыбом становятся (см. хотя бы
boost spirit). Хорошее это средство или плохое, в любом случае оно позволяет сделать процесс написания кода крайне увлекательным занятием.
Эта статья о том, как вычислить квадратный корень из целого числа на этапе компиляции.
Если читатель хоть немного знаком с C++, могу предположить, что он уже пробовал реализовать вычисление факториала на этапе компиляции. Например, так:
template<unsigned x>
struct factorial
{
enum { value = x * factorial<x-1>::value };
};
template<>
struct factorial<1>
{
enum { value = 1 };
};
Чтобы вычислить факториал, достаточно написать:
factorial<5>::value
Это выражение обладает всеми свойствами, которыми обладает константа. Значение этого выражения можно использовать для объявления массива, для специализации/инстанциирования шаблонов и для всего чего угодно остального, т.е. значение известно на этапе компиляции.
Но вернёмся к извлечению корня. Вообще квадратным корнем из числа X называется такое неотрицательное число Y, квадрат которого равен X. Для целых чисел эта формулировка не подходит. Мы будем отталкиваться от формулировки: квадратными корнем из целого числа X называется такое наибольшее целое число Y, квадрат которого не превышает X.
Для начала я опишу алгоритм, который мы будем реализовывать. Пусть нам дано цело число x, а корень, который мы ищем - y. Пусть так же даны 2 числа lo и range, такие, что изначально lo = 0, range = x. На каждой итерации работы y будет серединой диапазона и на каждой итерации мы будем этот диапазон корректировать. Вот простой пример:
x=36
1. lo = 0, range = 36, y = lo + range / 2 = 18. 18*18 > 36, значит берём меньшую половину диапазона.
2. lo = 0, range = 18, y = lo + range / 2 = 9. 9*9 > 36, берём меньшую.
3. lo = 0, range = 9, y = lo + range / 2 = 4. 4*4 < 36, берём большую.
4. lo = 4, range = 4, y = lo + range / 2 = 6. 6*6 = 36 - это ответ.
Возможны и другие варианты развития событий, но о них - позже.
Начнём.
template
<
unsigned x, // число, корень которого ищем
unsigned lo = 0, // начало диапазона
unsigned range = x, // ширина диапазона
unsigned y = lo + range / 2 // значение корня для текущей итерации
>
struct sqrt;
Использование шаблонных параметров со значениями по умолчанию позволяет нам делать нужные вычисления. При инстанциировании этого шаблона в виде sqrt<36>, x будет равен 36, lo - 0, range - 36, y = 18. Что делать дальше? Нам нужно каким-то образом реализовать логику выбора - нижняя половина диапазона, верхняя, ответ уже найден. Для этого опишем вспомогательный класс:
// состояние алгоритма sqrt - уменьшать, увеличивать, вычислено
enum state { less, greater, equal };
// получение состояния алгоритма
template
<
unsigned x, // нужное число
unsigned y // кандидат на корень
>
struct sqrt_state
{
enum { value = x == y*y ? equal : (x < y*y ? greater : less) };
};
sqrt_state<x, y> проверяет отношения равенства/неравенства между x и y*y. В случае если y*y < x, результат sqrt_state
::value == less, если y*y > x, то greater, иначе equal. Изменим шаблон sqrt следующим образом:
// sqrt
template
<
unsigned x, // число для извлечения корня
unsigned lo = 0, // нижняя граница диапазона
unsigned range = x, // ширина диапазона
unsigned y = lo + range / 2, // середина диапазона
unsigned state = sqrt_state<x, y>::value // состояние алгоритма
>
struct sqrt;
Теперь при инстанциировании шаблона sqrt, ему известно что нужно делать. В случае less он должен обращаться сам к себе, передавая в качестве аргумента верхний поддиапазон, в случае greater - нижний, а в случае equal рекурсия должна заканчиваться, т.к. ответ найден:
// квадрат проверяемого числа меньше, чем нужное число
template<unsigned x, unsigned lo, unsigned range, unsigned y>
struct sqrt<x, lo, range, y, less>
{
enum { value = sqrt<x, y, range / 2>::value };
};
// квадрат проверяемого числа больше, чем нужное число
template<unsigned x, unsigned lo, unsigned range, unsigned y>
struct sqrt<x, lo, range, y, greater>
{
enum { value = sqrt<x, lo, range / 2>::value };
};
// квадрат проверяемого числа равен нужному числу
template<unsigned x, unsigned lo, unsigned range, unsigned y>
struct sqrt<x, lo, range, y, equal>
{
enum { value = y };
};
Теперь, если проверить написанное с помощью:
sqrt<36>::value
Получим правильный ответ - 6. А если попробовать 37? Ошибка компиляции. Почему? Всё очень просто. Дело дошло до того, когда range стал равным 0. При этом специализация
struct sqrt<x, lo, range, y, less>
попыталась вызвать саму себя с теми же параметрами, с которыми она была вызвана. Как же быть? Тут мы столкнулись с ситуацией, когда целочисленный корень из числа не существует. Добавим ещё одну специализацию sqrt:
// квадрат проверяемого числа меньше, чем нужное число, диапазон равен 1
// наибольшее целое, квадрат которого не превышает заданное
template<unsigned x, unsigned lo, unsigned y>
struct sqrt<x, lo, 1, y, less>
{
enum { value = (lo+1)*(lo+1) <= x ? (lo + 1) : lo };
};
Эта специализация начинает работать когда range=1. Когда это происходит? Когда корень числа находится между двумя целыми числами. Для 37 эти числа - 6 и 7. При вычислении value в этой специализации мы проверяем какое из двух чисел "лучше" подходит - 6 или 7. В данном случае выбераем 6. Вроде бы всё. Проверим как работает наш sqrt для других чисел. Для этого напишем вспомогательный класс static_assert (аналогичная штука уже есть в c++0x, но сделаем вид, что у нас нет c++0x):
static_assert<sqrt<0ul>::value == 0ul>();
static_assert<sqrt<1ul>::value == 1ul>();
static_assert<sqrt<2ul>::value == 1ul>();
static_assert<sqrt<3ul>::value == 1ul>();
static_assert<sqrt<4ul>::value == 2ul>();
static_assert<sqrt<5ul>::value == 2ul>();
static_assert<sqrt<6ul>::value == 2ul>();
static_assert<sqrt<7ul>::value == 2ul>();
static_assert<sqrt<8ul>::value == 2ul>();
static_assert<sqrt<9ul>::value == 3ul>();
static_assert<sqrt<13ul>::value == 3ul>();
static_assert<sqrt<100ul>::value == 10ul>();
static_assert<sqrt<256ul>::value == 16ul>();
static_assert<sqrt<257ul>::value == 16ul>();
static_assert<sqrt<2048ul>::value == 45ul>();
static_assert<sqrt<9801ul>::value == 99ul>();
static_assert<sqrt<65533ul>::value == 255ul>();
static_assert<sqrt<65534ul>::value == 255ul>();
static_assert<sqrt<65536ul>::value == 256ul>();
static_assert<sqrt<999999ul>::value == 999ul>();
static_assert<sqrt<1000000ul>::value == 1000ul>();
static_assert<sqrt<12345678ul>::value == 3513ul>();
static_assert<sqrt<123456789ul>::value == 11111ul>();
static_assert<sqrt<4294836223ul>::value == 65534ul>();
static_assert<sqrt<4294836224ul>::value == 65534ul>();
static_assert<sqrt<0xfffeul*0xfffeul>::value == 0xfffeul>();
static_assert<sqrt<0xfffful*0xfffful>::value == 0xfffful>();
static_assert<sqrt<0x10001ul*0xfffful>::value == 0xfffful>();
Пробуем откомпилировать и... компилятор сообщает нам о множестве ошибок и предупреждений. В чём же дело? Начнём с того, что в некоторых местах происходит целочисленное переполнение.
template
<
unsigned x, // нужное число
unsigned y // кандидат на корень
>
struct sqrt_state
{
enum { value = x == y*y ? equal : (x < y*y ? greater : less) };
};
Видим y*y. Проблема связана с тем, что в случае если unsigned - это 32-разрядное целое без знака, то область определения y для выражения y*y - это 1<<(32/2)-1, т.е. 0xffff. Но это частный случай. В общем случае можно реализовать дополнительный класс, который позволит нам проверять y на невыходимость за границы дозволенного его квадрата:
// проверка не превышение наибольшего числа, которое
// можно возвести в квадрат без переполнения
template<unsigned x>
struct is_greater_than_max_squarable
{
enum
{
max_squarable = (1 << (std::numeric_limits<unsigned>::digits / 2)) - 1,
value = x > max_squarable
};
};
Для использования std::numeric_limits нужно подключить стандартный заголовок limits. Название класса говорит само за себя. Исправим логическую часть алгоритма следующим образом:
// получение состояния алгоритма
template
<
unsigned x, // нужное число
unsigned y, // кандидат на корень
bool y_is_too_large = is_greater_than_max_squarable<y>::value
>
struct sqrt_state;
// y превышает наибольшее, которое можно возводить в квадрат
template<unsigned x, unsigned y>
struct sqrt_state<x, y, true>
{
enum { value = greater };
};
// y не превышает наибольшее, которое можно возводить в квадрат
template<unsigned x, unsigned y>
struct sqrt_state<x, y, false>
{
enum { value = x == y*y ? equal : (x < y*y ? greater : less) };
};
Идея в том, что квадрат числа, большего чем 0xffff даст число, большее чем умещается в переменной типа unsigned. Наша модификация заставляет sqrt_state расценивать число как слишком большое в случае если оно больше, чем 0xffff.
Снова компилируем и... снова ошибка. На этот раз одна. Речь идёт о строке:
static_assert<sqrt<0x10001ul*0xfffful>::value == 0xfffful>();
Что же здесь могло произойти? Реальная проблема находится здесь:
enum { value = (lo+1)*(lo+1) <= x ? (lo + 1) : lo };
lo у нас вполне может оказаться равным 0xffff, соответственно (0xffff+1)*(0xffff+1) это 0x10000*0x10000 - а это много. Проблема решается следующим волшебным образом:
enum { value = lo * lo + 2 * lo <= x - 1 ? (lo + 1) : lo };
В таком варианте переполнения не происходит, хотя исправленное условное выражение означает то же самое, что и исходное. Кстати, вопросы по поводу разницы между a+b+c и a+c+b иногда задают на собеседованиях :)
В качестве заключения: sqrt.