Skip to content
Snippets Groups Projects

Mergesort: recursive vs iterative

  • Clone with SSH
  • Clone with HTTPS
  • Embed
  • Share
    The snippet can be accessed without any authentication.
    Authored by Kaspar Lutter

    A brief benchmark comparing the performance of the recursive and iterative mergesort variants, each with and without an insertion-sort cutoff.

    Edited
    main.cpp 5.57 KiB
    #include <algorithm>
    #include <cassert>
    #include <chrono>
    #include <cstddef>
    #include <functional>
    #include <iostream>
    #include <random>
    #include <vector>
    
    // Merges the adjacent sorted ranges vec[left, mid) and vec[mid, right) back
    // into vec[left, right).
    //
    // Both halves are copied into temporaries because std::merge requires that
    // its output range does not overlap either input range. The left half is
    // passed first: for equivalent elements std::merge takes the one from its
    // first range, so this order keeps the mergesort stable.
    void merge(std::vector<int> &vec, int left, int mid, int right) {
      const std::vector<int> leftVec(vec.begin() + left, vec.begin() + mid);
      const std::vector<int> rightVec(vec.begin() + mid, vec.begin() + right);

      std::merge(leftVec.begin(), leftVec.end(), rightVec.begin(), rightVec.end(),
                 vec.begin() + left);
    }
    
    // Sorts the half-open range vec[left, right) in ascending order by
    // insertion sort. Each element is lifted out, larger predecessors are
    // shifted one slot to the right, and the element is dropped into the hole.
    // Only strictly greater elements are shifted, so equal values keep their
    // relative order.
    void insertionSort(std::vector<int> &vec, int left, int right) {
      for (int pos = left + 1; pos < right; ++pos) {
        const int value = vec[pos];
        int hole = pos;
        for (; hole > left && vec[hole - 1] > value; --hole)
          vec[hole] = vec[hole - 1];
        vec[hole] = value;
      }
    }
    
    // Top-down mergesort of the half-open range vec[left, right).
    //
    // Ranges with fewer than two elements are already sorted. With
    // useInsertionSort set, ranges shorter than 48 elements are finished by
    // insertionSort instead of being split further. Otherwise the range is
    // split at its midpoint, both halves are sorted recursively, and then they
    // are merged.
    void mergeSortRecursive(std::vector<int> &vec, int left, int right,
                            bool useInsertionSort) {
      const int length = right - left;
      if (length < 2)
        return;

      const bool smallEnoughForInsertion = useInsertionSort && length < 48;
      if (smallEnoughForInsertion) {
        insertionSort(vec, left, right);
      } else {
        const int mid = left + length / 2;
        mergeSortRecursive(vec, left, mid, useInsertionSort);
        mergeSortRecursive(vec, mid, right, useInsertionSort);
        merge(vec, left, mid, right);
      }
    }
    
    void mergeSortIterative(std::vector<int> &vec, bool useInsertionSort) {
      int n = vec.size();
    
      int insertion_sort_size = 12;
      int start_size = useInsertionSort ? insertion_sort_size : 1;
    
      if (start_size >= n)
        insertionSort(vec, 0, n);
    
      for (int curr_size = start_size; curr_size < n; curr_size = 2 * curr_size) {
        for (int left_start = 0; left_start < n - 1; left_start += 2 * curr_size) {
          int mid = std::min(left_start + curr_size, n - 1);
          int right_end = std::min(left_start + 2 * curr_size, n);
          if (useInsertionSort && curr_size == insertion_sort_size) {
            insertionSort(vec, left_start, right_end);
          } else {
            merge(vec, left_start, mid, right_end);
          }
        }
      }
    }
    
    // Invokes `t` once and returns how long the call took, in fractional
    // seconds, as measured by std::chrono::steady_clock.
    std::chrono::duration<double> measure(std::function<void()> t) {
      using Clock = std::chrono::steady_clock;
      const Clock::time_point begin = Clock::now();
      t();
      const Clock::time_point end = Clock::now();
      return std::chrono::duration<double>{end - begin};
    }
    
    // Returns a benchmark input size of roughly 2^i. Exactly 2^i is avoided on
    // purpose: it lines up too precisely with cache sizes and therefore causes
    // large measurement errors. The two small extra terms 2^(i >> 1) and
    // 2^(i >> 3) move the size off the power of two.
    int calculate_data_size(int i) {
      const int powerOfTwo = 1 << i;
      const int offsetLarge = 1 << (i >> 1);
      const int offsetSmall = 1 << (i >> 3);
      return powerOfTwo + offsetLarge + offsetSmall;
    }
    
    // Prints the CSV header of one result table. The column names get a
    // " + insertionsort" suffix when the insertion-sort cutoff is enabled.
    static void printHeader(bool withInsertionSort) {
      const char *suffix = withInsertionSort ? " + insertionsort" : "";
      std::cout << "n,rekursiv" << suffix << ",iterativ" << suffix
                << ",rekursiv vs iterativ" << std::endl;
    }

    // Prints one CSV row: the input size, both timings truncated to whole
    // microseconds, and the ratio (recursive time / iterative time).
    static void printRow(int dataSize, std::chrono::duration<double> recursive,
                         std::chrono::duration<double> iterative) {
      using std::chrono::duration_cast;
      using std::chrono::microseconds;
      std::cout << dataSize << "," << duration_cast<microseconds>(recursive).count()
                << "," << duration_cast<microseconds>(iterative).count() << ","
                << (recursive / iterative) << std::endl;
    }

    // Benchmarks recursive vs. iterative mergesort and prints CSV tables to
    // stdout. The first phase times single sorts of large inputs; the second
    // phase times batches of many small inputs. Sorted results are checked
    // with assert, which is only active when NDEBUG is not defined.
    int main() {
      std::mt19937 gen{};
      std::uniform_int_distribution<> dist(-1'000'000'000, 1'000'000'000);
      auto randomValue = [&]() { return dist(gen); };

      // Only the ratio column is floating point; set its precision once.
      std::cout.precision(3);

      // Runs 0-1 use plain mergesort and runs 2-3 enable the insertion-sort
      // cutoff. Each configuration is measured twice (presumably so the first
      // table serves as a warm-up).
      for (int test_run = 0; test_run < 4; test_run++) {
        const bool useInsertionSort = test_run > 1;
        printHeader(useInsertionSort);

        for (int i = 8; i < 25; i++) {
          const int data_size = calculate_data_size(i);
          std::vector<int> buff(data_size);

          std::generate(buff.begin(), buff.end(), randomValue);
          const auto elapseda = measure(
              [&]() { mergeSortRecursive(buff, 0, data_size, useInsertionSort); });
          assert((void("recursive sort failed"),
                  std::is_sorted(buff.begin(), buff.end())));

          std::generate(buff.begin(), buff.end(), randomValue);
          const auto elapsedb =
              measure([&]() { mergeSortIterative(buff, useInsertionSort); });
          assert((void("iterative sort failed"),
                  std::is_sorted(buff.begin(), buff.end())));

          printRow(data_size, elapseda, elapsedb);
        }
      }

      // Small inputs: each measurement sorts small_cases independent arrays.
      const std::size_t small_cases = 50'000;

      for (int test_run = 0; test_run < 4; test_run++) {
        const bool useInsertionSort = test_run > 1;
        printHeader(useInsertionSort);

        for (int i = 3; i < 11; i++) {
          const int data_size = calculate_data_size(i);

          // Heap-allocated on purpose. A std::array of 50'000 vectors is about
          // 1.2 MB with a typical 24-byte std::vector, which can overflow a
          // 1 MB default stack.
          std::vector<std::vector<int>> cases(small_cases,
                                              std::vector<int>(data_size));
          auto refillCases = [&]() {
            for (auto &single : cases)
              std::generate(single.begin(), single.end(), randomValue);
          };

          refillCases();
          const auto elapseda = measure([&]() {
            for (auto &buff : cases)
              mergeSortRecursive(buff, 0, data_size, useInsertionSort);
          });
          for (const auto &sorted : cases) {
            assert((void("recursive sort failed"),
                    std::is_sorted(sorted.begin(), sorted.end())));
          }

          refillCases();
          const auto elapsedb = measure([&]() {
            for (auto &buff : cases)
              mergeSortIterative(buff, useInsertionSort);
          });
          for (const auto &sorted : cases) {
            assert((void("iterative sort failed"),
                    std::is_sorted(sorted.begin(), sorted.end())));
          }

          printRow(data_size, elapseda, elapsedb);
        }
      }
    }
    0% Loading or .
    You are about to add 0 people to the discussion. Proceed with caution.
    Finish editing this message first!
    Please register or sign in to comment