Sort a linked list in O(n log n) time and constant extra space.

Example

List       Return
null       null
1->2       1->2
1->3->2    1->2->3
3->2->1    1->2->3

The following is an implementation of merge sort with O(n log n) time complexity.

/**
 * Sorts a singly linked list in ascending order of {@code val}.
 *
 * <p>The sort is an iterative (bottom-up) merge sort: it runs in O(n log n)
 * time, relinks the existing nodes instead of allocating new ones, and uses
 * no recursion, so the extra space is constant. Nodes with equal values keep
 * their original relative order (the sort is stable).
 */
public class Solution {
    /**
     * Sorts the list starting at {@code head}.
     *
     * @param head the head of the linked list; may be {@code null}
     * @return the head of the sorted linked list, or {@code null} for an empty list
     */
    public ListNode sortList(ListNode head) {
        return mergeSort(head);
    }

    /**
     * Merge sorts the list starting at {@code head} by relinking its nodes.
     *
     * <p>Runs of {@code width} nodes are merged pairwise; the width doubles on
     * every pass until a single run covers the whole list.
     *
     * @param head the head of the list to sort; may be {@code null}
     * @return the head of the sorted list
     */
    ListNode mergeSort(ListNode head) {
        int len = getLength(head);
        if (len <= 1) {
            return head;
        }

        // "width > 0" stops the loop if doubling ever overflows int.
        for (int width = 1; width > 0 && width < len; width <<= 1) {
            ListNode sortedHead = null; // Head of the list rebuilt in this pass
            ListNode sortedTail = null; // Last node appended in this pass
            ListNode rest = head;       // Nodes not yet processed in this pass

            while (rest != null) {
                // Detach two adjacent runs of at most "width" nodes each
                ListNode left = rest;
                ListNode right = split(left, width);
                rest = split(right, width);

                ListNode merged = merge(left, right);

                if (sortedTail == null) {
                    sortedHead = merged;
                } else {
                    sortedTail.next = merged;
                }
                // "merged" is never null here because "left" is not null
                sortedTail = getTail(merged);
            }

            head = sortedHead;
        }

        return head;
    }

    /**
     * Cuts the list after its first {@code size} nodes.
     *
     * @param head the head of the list to cut; may be {@code null}
     * @param size the number of nodes to keep attached to {@code head} (at least 1)
     * @return the head of the detached remainder, or {@code null} if the list
     *         has at most {@code size} nodes
     */
    private static ListNode split(ListNode head, int size) {
        for (int i = 1; head != null && i < size; i++) {
            head = head.next;
        }
        if (head == null) {
            return null;
        }
        ListNode remainder = head.next;
        head.next = null;
        return remainder;
    }

    /**
     * Merges two sorted lists into one sorted list by relinking their nodes.
     *
     * <p>On equal values the node from {@code left} is taken first, which keeps
     * the overall sort stable.
     *
     * @param left  the head of the first sorted list; may be {@code null}
     * @param right the head of the second sorted list; may be {@code null}
     * @return the head of the merged list
     */
    private static ListNode merge(ListNode left, ListNode right) {
        ListNode mergedHead = null;
        ListNode mergedTail = null;

        while (left != null && right != null) {
            ListNode next;
            if (left.val <= right.val) {
                next = left;
                left = left.next;
            } else {
                next = right;
                right = right.next;
            }

            if (mergedTail == null) {
                mergedHead = next;
            } else {
                mergedTail.next = next;
            }
            mergedTail = next;
        }

        // At most one of the inputs still has nodes; it is already sorted.
        ListNode remaining = (left != null) ? left : right;
        if (mergedTail == null) {
            return remaining;
        }
        mergedTail.next = remaining;
        return mergedHead;
    }

    /**
     * Gets the very last node of the list.
     *
     * @param head the head of the list; may be {@code null}
     * @return the last node, or {@code null} if the list is empty
     */
    ListNode getTail(ListNode head) {
        if (head == null) {
            return null;
        }
        while (head.next != null) {
            head = head.next;
        }
        return head;
    }

    /**
     * Gets the length of the list.
     *
     * @param head the head of the list; may be {@code null}
     * @return the number of nodes reachable from {@code head}
     */
    int getLength(ListNode head) {
        int len = 0;
        while (head != null) {
            len++;
            head = head.next;
        }
        return len;
    }
}
/**
 * Node of a singly linked list holding one {@code int} value.
 */
public class ListNode {
    /** Value stored in this node. */
    int val;

    /** Node that follows this one, or {@code null} if this is the last node. */
    ListNode next = null;

    /**
     * Creates a node with the given value and no successor.
     *
     * @param val the value to store in the node
     */
    ListNode(int val) {
        this.val = val;
    }
}