Advertisement
Guest User

Untitled

a guest
Oct 12th, 2021
123
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 4.40 KB | None | 0 0
  1. import math
  2. import os
  3. import sys
  4. from io import BytesIO, IOBase
  5. import threading
  6. M = 1000000007
  7. import random
  8. import heapq
  9. import bisect
  10. import time
  11.  
  12. sys.setrecursionlimit(10 ** 5+50)
  13. from functools import *
  14. from collections import *
  15. from itertools import *
  16.  
  17. BUFSIZE = 8192
  18. import array
  19.  
  20.  
  21. class FastIO(IOBase):
  22.     newlines = 0
  23.  
  24.     def __init__(self, file):
  25.         self._fd = file.fileno()
  26.         self.buffer = BytesIO()
  27.         self.writable = "x" in file.mode or "r" not in file.mode
  28.         self.write = self.buffer.write if self.writable else None
  29.  
  30.     def read(self):
  31.         while True:
  32.             b = os.read(self._fd, max(os.fstat(self._fd).st_size, BUFSIZE))
  33.             if not b:
  34.                 break
  35.             ptr = self.buffer.tell()
  36.             self.buffer.seek(0, 2), self.buffer.write(b), self.buffer.seek(ptr)
  37.         self.newlines = 0
  38.         return self.buffer.read()
  39.  
  40.     def readline(self):
  41.         while self.newlines == 0:
  42.             b = os.read(self._fd, max(os.fstat(self._fd).st_size, BUFSIZE))
  43.             self.newlines = b.count(b"\n") + (not b)
  44.             ptr = self.buffer.tell()
  45.             self.buffer.seek(0, 2), self.buffer.write(b), self.buffer.seek(ptr)
  46.         self.newlines -= 1
  47.         return self.buffer.readline()
  48.  
  49.     def flush(self):
  50.         if self.writable:
  51.             os.write(self._fd, self.buffer.getvalue())
  52.             self.buffer.truncate(0), self.buffer.seek(0)
  53.  
  54.  
  55. class IOWrapper(IOBase):
  56.     def __init__(self, file):
  57.         self.buffer = FastIO(file)
  58.         self.flush = self.buffer.flush
  59.         self.writable = self.buffer.writable
  60.         self.write = lambda s: self.buffer.write(s.encode("ascii"))
  61.         self.read = lambda: self.buffer.read().decode("ascii")
  62.         self.readline = lambda: self.buffer.readline().decode("ascii")
  63.  
  64.  
  65. def print(*args, **kwargs):
  66.     sep, file = kwargs.pop("sep", " "), kwargs.pop("file", sys.stdout)
  67.     at_start = True
  68.     for x in args:
  69.         if not at_start:
  70.             file.write(sep)
  71.         file.write(str(x))
  72.         at_start = False
  73.     file.write(kwargs.pop("end", "\n"))
  74.     if kwargs.pop("flush", False):
  75.         file.flush()
  76.  
  77.  
  78. if sys.version_info[0] < 3:
  79.     sys.stdin, sys.stdout = FastIO(sys.stdin), FastIO(sys.stdout)
  80. else:
  81.     sys.stdin, sys.stdout = IOWrapper(sys.stdin), IOWrapper(sys.stdout)
  82. input = lambda: sys.stdin.readline().rstrip("\r\n")
  83.  
  84.  
  85. def inp(): return sys.stdin.readline().rstrip("\r\n")  # for fast input
  86.  
  87.  
  88. def out(var): sys.stdout.write(str(var))  # for fast output, always take string
  89.  
  90.  
  91. def lis(): return list(map(int, inp().split()))
  92.  
  93.  
  94. def stringlis(): return list(map(str, inp().split()))
  95.  
  96.  
  97. def sep(): return map(int, inp().split())
  98.  
  99.  
  100. def strsep(): return map(str, inp().split())
  101.  
  102.  
  103. def fsep(): return map(float, inp().split())
  104.  
  105.  
  106. def inpu(): return int(inp())
  107.  
  108.  
  109. def build(arr, a, b, st, node):
  110.     if a == b:
  111.         st[node] = arr[a]
  112.         return st[node]
  113.     mid = (a + b) // 2
  114.     st[node] = min(build(arr, a, mid, st, 2 * node + 1), build(arr, mid + 1, b, st, 2 * node + 2))
  115.     return st[node]
  116.  
  117.  
  118. def getmin(arr, a, b, l, r, st, node):
  119.     if l > b or r < a:
  120.         return float("inf")
  121.     if l <= a and r >= b:
  122.         return st[node]
  123.     mid = (a + b) // 2
  124.     return min(getmin(arr, a, mid, l, r, st, 2 * node + 1), getmin(arr, mid + 1, b, l, r, st, 2 * node + 2))
  125.  
  126.  
  127. def dfs(cur, visited, d, res):
  128.     visited[cur] = True
  129.     for i in d[cur]:
  130.         if not visited[i]:
  131.             res[0] += 1
  132.             dfs(i, visited, d, res)
  133.  
  134.  
  135. def main():
  136.     t = 1
  137.     # t=int(input())
  138.     for _ in range(t):
  139.         n = inpu()
  140.         arr = lis()
  141.         d = defaultdict(list)
  142.         for i in range(n):
  143.             d[i].append(arr[i] - 1)
  144.             d[arr[i] - 1].append(i)
  145.         visited = [False] * (n + 1)
  146.         ans = []
  147.         for i in range(n):
  148.             res = [1]
  149.             if not visited[i]:
  150.                 dfs(i, visited, d, res)
  151.                 ans.append(res[0])
  152.         ans.sort()
  153.         if len(ans) == 1:
  154.             print(ans[0] * ans[0])
  155.         elif len(ans) > 1:
  156.             c = ans[-1] + ans[-2]
  157.             res1 = c * c
  158.             for i in range(len(ans) - 2):
  159.                 res1 += (ans[i] * ans[i])
  160.             print(res1)
  161.         else:
  162.             print(0)
  163.  
  164.  
  165. if __name__ == '__main__':
  166.     threading.stack_size(2*10**8)
  167.     threading.Thread(target=main).start()
  168.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement