본문 바로가기

코딩/BOJ

BOJ 4013 : ATM

문제 링크 : https://www.acmicpc.net/problem/4013

 

4013번: ATM

첫째 줄에 교차로의 수와 도로의 수를 나타내는 2개의 정수 N과 M(N, M ≤ 500,000)이 차례로 주어진다. 교차로는 1부터 N까지 번호로 표시된다. 그 다음 M개의 줄에는 각 줄마다 각 도로의 시작 교차

www.acmicpc.net

 

SCC 응용 문제입니다. 간단해 보이지만 처리해줄게 많고 구현량이 많아 생각보다 어려웠습니다.

 

문제를 간략하게 정리하면, 그래프가 주어지며 각 정점은 현금을 가지고 있습니다. 이때, 시작 정점과 도착가능한 정점이 주어질 때 인출할 수 있는 현금의 최댓값을 출력해야 합니다.

현금의 액수는 0이상의 정수이기 때문에 하나의 SCC에서는 모든 정점의 현금을 인출하는 것이 이득일 것입니다. 따라서 다른 과정을 거치지 전에 먼저 SCC를 구하고, 각 SCC에 속하는 정점들이 가진 현금의 합을 배열에 정리하였습니다.

이후, 정점과 간선으로 이루어진 그래프를 SCC에 대한 그래프(DAG)로 변환해주었습니다. 모든 간선을 확인하며 간선의 양쪽에 있는 정점이 다른 SCC에 속해있다면 set에 insert 해주었습니다.

SCC에 대한 그래프를 만들었다면 이 그래프에 대해 트리DP를 돌려주면 됩니다. 트리 DP의 경우 이전 정점까지의 DP값 + 현재 정점의 현금의 합(현재 SCC의 현금의 합)으로 업데이트를 해주면 됩니다. 만약, 이미 방문한 정점에 다시 방문했을 경우, DP값에 업데이트가 생기는 경우에는 계속 방문해주어도 됩니다.

최종적으로 도착가능한 정점(레스토랑이 있는 교차로)들이 속한 DP값들을 비교하여 최댓값을 출력해주면 됩니다. 

#include <bits/stdc++.h>
using namespace std;

int n, m, cnt = 0, indegvisit[505050] = {0}, val[505050] = {0}, s, p, r[505050] = {0}, chk[505050] = {0}, scc[505050] = {0};
vector<int> v[505050], r_v[505050], st, indeg[505050], c[505050], visits;
set<int> tmp[505050];
long long scchap[505050] = {0}, dp[505050] = {0};

void dfs(int cur) {
   chk[cur] = 1;
   sort(v[cur].begin(), v[cur].end());
   for (int i = 0; i < v[cur].size(); i++) {
      if(!chk[v[cur][i]]) dfs(v[cur][i]);
   }
   st.push_back(cur);
}

void rdfs(int cur) {
   scc[cur] = cnt;
   sort(r_v[cur].begin(), r_v[cur].end());
   for (int i = 0; i < r_v[cur].size(); i++) {
      if(!scc[r_v[cur][i]]) rdfs(r_v[cur][i]);
   }
}

void sccindeg(int cur) {
   visits = vector<int>(cnt+1, 0);
   visits[cur] = 1;
   for (int i = 0; i < c[cur].size(); i++) {
      for (int j = 0; j < v[c[cur][i]].size(); j++) {
         if(!visits[scc[v[c[cur][i]][j]]]) {
            indeg[cur].push_back(scc[v[c[cur][i]][j]]);
            visits[scc[v[c[cur][i]][j]]] = 1;
         }
      }
   }
}

void sdfs(int cur) {
   indegvisit[cur] = 1;
   for (int i = 0; i < indeg[cur].size(); i++) {
      if(indegvisit[indeg[cur][i]]) {
         if(dp[cur]+scchap[indeg[cur][i]] > dp[indeg[cur][i]]) {
            dp[indeg[cur][i]] = dp[cur]+scchap[indeg[cur][i]];
            sdfs(indeg[cur][i]);
         }
      }
      else {
         dp[indeg[cur][i]] = dp[cur]+scchap[indeg[cur][i]];
         sdfs(indeg[cur][i]);
      }
   }
}

int main() {
   scanf("%d %d", &n, &m);
   for (int i = 1; i <= m; i++) {
      int x, y;
      scanf("%d %d", &x, &y);
      v[x].push_back(y);
      r_v[y].push_back(x);
   }
   for (int i = 1; i <= n; i++) scanf("%d", &val[i]);
   scanf("%d %d", &s, &p);
   for (int i = 1; i <= p; i++){
      scanf("%d", &r[i]);
   }

   for (int i = 1; i <= n; i++) {
      if(!chk[i]) dfs(i);
   }
   for (int i = st.size()-1; i >= 0; i--) {
      if(!scc[st[i]]) {cnt++; rdfs(st[i]);}
   }
   for (int i = 1; i <= n; i++) {
      scchap[scc[i]] += val[i];
      c[scc[i]].push_back(i);
   }
   for (int i = 1; i <= n; i++) {
      for (int j = 0; j < v[i].size(); j++) {
         if(scc[i] != scc[v[i][j]]) {
            tmp[scc[i]].insert(scc[v[i][j]]);
         }
      }
   }

   for (int i = 1; i <= cnt; i++) {
      for (auto e : tmp[i]) {
         indeg[i].push_back(e);
      }
   }
   dp[scc[s]] = scchap[scc[s]];
   sdfs(scc[s]);
   long long ans = 0;
   for (int i = 1; i <= p; i++) {
      ans = max(ans, dp[scc[r[i]]]);
   }
   printf("%lld", ans);
}