[LC 310] Minimum height trees

Calculating the pairwise distances between all nodes gets a TLE (Time Limit Exceeded).
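For reference, the brute-force idea that times out can be sketched like this: run a BFS from every node to get the height of the tree rooted there, then keep the roots with the smallest height. This is a minimal sketch, assuming the input is a valid tree on nodes 0 to n-1; the class name BruteForceMHT is hypothetical. Each BFS costs O(n), so the whole thing is O(n^2), which is what exceeds the time limit on large inputs.

```java
import java.util.*;

// Hypothetical brute-force baseline: O(n) BFS per candidate root, O(n^2) overall.
class BruteForceMHT {
    static List<Integer> findMinHeightTrees(int n, int[][] edges) {
        List<List<Integer>> adj = new ArrayList<>();
        for (int i = 0; i < n; i++) adj.add(new ArrayList<>());
        for (int[] e : edges) {
            adj.get(e[0]).add(e[1]);
            adj.get(e[1]).add(e[0]);
        }
        int best = Integer.MAX_VALUE;
        List<Integer> roots = new ArrayList<>();
        for (int r = 0; r < n; r++) {
            int h = height(adj, r, n);
            if (h < best) { // strictly better root: forget the previous ones
                best = h;
                roots.clear();
            }
            if (h == best) roots.add(r);
        }
        return roots;
    }

    // Height of the tree rooted at r = largest BFS distance from r.
    private static int height(List<List<Integer>> adj, int r, int n) {
        int[] dist = new int[n];
        Arrays.fill(dist, -1); // -1 marks "not visited yet"
        Deque<Integer> queue = new ArrayDeque<>();
        dist[r] = 0;
        queue.add(r);
        int max = 0;
        while (!queue.isEmpty()) {
            int u = queue.poll();
            max = Math.max(max, dist[u]);
            for (int v : adj.get(u)) {
                if (dist[v] == -1) {
                    dist[v] = dist[u] + 1;
                    queue.add(v);
                }
            }
        }
        return max;
    }
}
```

For n = 4 and edges [[1,0],[1,2],[1,3]] it returns [1].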
An observation: repeatedly finding and removing the leaves (nodes with degree 1) leaves us with the solution! The one or two nodes that remain lie in the middle of the longest path in the tree.
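The "middle of the longest path" claim can be checked with a different technique than leaf removal: in a tree, a BFS from any node ends at one end of a longest path, and a second BFS from that end reaches the other end. Walking back along the BFS parents recovers the path, and its middle one or two nodes are the answer. This is a sketch assuming a valid tree on nodes 0 to n-1; DiameterCenter, centers and farthest are hypothetical names.

```java
import java.util.*;

// Hypothetical helper: returns the middle node(s) of a longest path in a tree.
// Assumes a valid tree on nodes 0..n-1 (connected, n-1 edges).
class DiameterCenter {
    static List<Integer> centers(int n, int[][] edges) {
        List<List<Integer>> adj = new ArrayList<>();
        for (int i = 0; i < n; i++) adj.add(new ArrayList<>());
        for (int[] e : edges) {
            adj.get(e[0]).add(e[1]);
            adj.get(e[1]).add(e[0]);
        }
        int[] parent = new int[n];
        int a = farthest(adj, 0, parent); // BFS from any node ends at one end of a longest path
        int b = farthest(adj, a, parent); // BFS from that end reaches the other end
        // Walk parent links from b back to a to recover the path itself.
        List<Integer> path = new ArrayList<>();
        for (int v = b; v != -1; v = parent[v]) path.add(v);
        int len = path.size();
        List<Integer> mid = new ArrayList<>();
        mid.add(path.get((len - 1) / 2));
        if (len % 2 == 0) mid.add(path.get(len / 2)); // even node count: two middle nodes
        Collections.sort(mid);
        return mid;
    }

    // BFS from src; fills parent[] (-1 for src) and returns the last node dequeued,
    // which is at maximum distance from src.
    private static int farthest(List<List<Integer>> adj, int src, int[] parent) {
        Arrays.fill(parent, -2); // -2 marks "not visited yet"
        parent[src] = -1;
        Deque<Integer> queue = new ArrayDeque<>();
        queue.add(src);
        int last = src;
        while (!queue.isEmpty()) {
            int u = queue.poll();
            last = u;
            for (int v : adj.get(u)) {
                if (parent[v] == -2) {
                    parent[v] = u;
                    queue.add(v);
                }
            }
        }
        return last;
    }
}
```

For n = 6 and edges [[3,0],[3,1],[3,2],[3,4],[5,4]] the longest path has four nodes (2-3-4-5), so two centers come back: [3, 4].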
Java code:
import java.util.*;

public List<Integer> findMinHeightTree(int n, int[][] edges) {
    // a single node with no edges is its own (and only) minimum height root
    if (n == 1) return Collections.singletonList(0);
    // store the graph as adjacency lists (one set per node); this works for any graph with nodes 0 to n-1
    List<Set<Integer>> adj = new ArrayList<>();
    for (int i = 0; i < n; i++) adj.add(new HashSet<>());
    for (int[] e : edges) {
        adj.get(e[0]).add(e[1]);
        adj.get(e[1]).add(e[0]);
    }

    // find all leaves
    List<Integer> leaves = new ArrayList<>();
    for (int i = 0; i < n; i++) {
        if (adj.get(i).size() == 1) leaves.add(i);
    }
    // iteratively remove leaves until only one or two nodes remain
    while (n > 2) {
        n -= leaves.size();
        List<Integer> newLeaves = new ArrayList<>();
        for (int i : leaves) {
            int j = adj.get(i).iterator().next(); // the only node connected to this leaf
            adj.get(j).remove(i);                 // remove the connection between that node and this leaf
            if (adj.get(j).size() == 1) newLeaves.add(j); // the node's degree dropped by one; it may be a new leaf
        }
        leaves = newLeaves;
    }
    return leaves;
}
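A variant of the same leaf-peeling idea keeps the degrees in an int array instead of deleting entries from hash sets, so the adjacency lists stay read-only and each removal is just a counter decrement. This is a sketch under the same assumptions (a valid tree on nodes 0 to n-1); LeafTrimming is a hypothetical name, and the result is sorted only to make the output order predictable.

```java
import java.util.*;

// Hypothetical variant: leaf trimming with a degree array; a degree of 0 means "already removed".
class LeafTrimming {
    static List<Integer> findMinHeightTrees(int n, int[][] edges) {
        if (n == 1) return new ArrayList<>(Collections.singletonList(0));
        List<List<Integer>> adj = new ArrayList<>();
        for (int i = 0; i < n; i++) adj.add(new ArrayList<>());
        int[] degree = new int[n];
        for (int[] e : edges) {
            adj.get(e[0]).add(e[1]);
            adj.get(e[1]).add(e[0]);
            degree[e[0]]++;
            degree[e[1]]++;
        }
        List<Integer> leaves = new ArrayList<>();
        for (int i = 0; i < n; i++) if (degree[i] == 1) leaves.add(i);
        int remaining = n;
        while (remaining > 2) {
            remaining -= leaves.size();
            List<Integer> next = new ArrayList<>();
            for (int leaf : leaves) {
                degree[leaf] = 0; // mark the leaf as removed
                for (int nb : adj.get(leaf)) {
                    // only a neighbor that is still present loses a degree
                    if (degree[nb] > 0 && --degree[nb] == 1) next.add(nb);
                }
            }
            leaves = next;
        }
        Collections.sort(leaves);
        return leaves;
    }
}
```

For n = 6 and edges [[3,0],[3,1],[3,2],[3,4],[5,4]], the first round removes leaves 0, 1, 2 and 5, leaving [3, 4].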

