[SWEA] 3000. 중간값 구하기 (C++, 라이브러리 X)


Heap 기본문제

풀이

Max-Heap과 Min-Heap을 이용해서 중간값을 실시간으로 구할 수 있습니다.

Max-Heap과 Min-Heap을 따로 구현해주었습니다.

MinHeap에는 MaxHeap[0]보다 큰 값들만 넣어주고 작은 값들은 MaxHeap에 넣어줍니다.

Add와 Pop을 적절히 이용해 MinHeap의 크기 + 1 = MaxHeap의 크기가 되도록 하였습니다.

MaxHeap[0]에는 항상 중간값이 저장됩니다.

참고: 중앙값(Median) 찾기

코드 GitHub

#include <iostream>
using namespace std;

#define MAXN 200000

int N;

int maxHeap[MAXN + 2];
int minHeap[MAXN + 1];

int maxCnt, minCnt;

void AddMax(int val) {
    int pivot = maxCnt;
    maxHeap[maxCnt++] = val;
    
    while(pivot > 0) {
        int ppivot = pivot;
        pivot = pivot - 1 >> 1;
        if(maxHeap[pivot] < val) {
            maxHeap[ppivot] = maxHeap[pivot];
            maxHeap[pivot] = val;
        } else return;
    }
}

void AddMin(int val) {
    int pivot = minCnt;
    minHeap[minCnt++] = val;
    
    while(pivot > 0) {
        int ppivot = pivot;
        pivot = pivot - 1 >> 1;
        if(minHeap[pivot] > val) {
            minHeap[ppivot] = minHeap[pivot];
            minHeap[pivot] = val;
        } else return;
    }
}

int PopMax() {
    int val = maxHeap[0];
    maxHeap[0] = maxHeap[--maxCnt];

    int pivot = 0;
    while(true) {
        int ppivot = pivot;
        pivot = (pivot << 1) + 1;

        if(maxCnt <= pivot) break;
        if(maxCnt == pivot + 1) {
            if(maxHeap[pivot] > maxHeap[ppivot]) {
                int tmp = maxHeap[pivot];
                maxHeap[pivot] = maxHeap[ppivot];
                maxHeap[ppivot] = tmp;
            } else break;
        } else {
            if(maxHeap[pivot] < maxHeap[pivot + 1]) pivot++;
            if(maxHeap[pivot] > maxHeap[ppivot]) {
                int tmp = maxHeap[pivot];
                maxHeap[pivot] = maxHeap[ppivot];
                maxHeap[ppivot] = tmp;
            } else break;
        }
    }
    return val;
}

int PopMin() {
    int val = minHeap[0];
    minHeap[0] = minHeap[--minCnt];

    int pivot = 0;
    while(true) {
        int ppivot = pivot;
        pivot = (pivot << 1) + 1;

        if(minCnt <= pivot) break;
        if(minCnt == pivot + 1) {
            if(minHeap[pivot] < minHeap[ppivot]) {
                int tmp = minHeap[pivot];
                minHeap[pivot] = minHeap[ppivot];
                minHeap[ppivot] = tmp;
            } else break;
        } else {
            if(minHeap[pivot] > minHeap[pivot + 1]) pivot++;
            if(minHeap[pivot] < minHeap[ppivot]) {
                int tmp = minHeap[pivot];
                minHeap[pivot] = minHeap[ppivot];
                minHeap[ppivot] = tmp;
            } else break;
        }
    }
    return val;
}


int main() {
    ios::sync_with_stdio(false);
    cin.tie(NULL);

    int T;
    cin >> T;
    for(int tc = 1; tc <=T; tc++) {
        int A;
        cin >> N >> A;
        maxCnt = minCnt = 0;
        AddMax(A);
        int sum = 0;
        for(int n = 0; n < N; n++) {
            int X, Y;
            cin >> X >> Y;
            if(X > maxHeap[0]) AddMin(X);
            else AddMax(X);
            if(Y > maxHeap[0]) AddMin(Y);
            else AddMax(Y);
            if(maxCnt < minCnt) AddMax(PopMin());
            else if(maxCnt - 1 > minCnt) AddMin(PopMax());
            sum += maxHeap[0];
            if(sum >= 20171109) sum %= 20171109;
        }
        cout << '#' << tc << ' ' << sum << '\n';
    }
}

Reference

댓글남기기