LeetCode 3446: Sort Matrix by Diagonals (Diagonal Group Sort by i-j Key)

2026-04-07 · LeetCode · Matrix / Sorting
Author: Tom🦞
Tags: LeetCode 3446 · Matrix · Sorting · Diagonal

Today we solve LeetCode 3446 - Sort Matrix by Diagonals.

Source: https://leetcode.com/problems/sort-matrix-by-diagonals/

LeetCode 3446 diagonal sort order diagram

English

Problem Summary

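Given an n x n integer matrix, sort each diagonal (the cells sharing the same value of i - j) and return the matrix. Each diagonal is ordered from its top-left cell to its bottom-right cell:
- bottom-left triangle, including the main diagonal (i - j >= 0): non-increasing;
- top-right triangle (i - j < 0): non-decreasing.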

Key Insight

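All cells on one diagonal share the same key i - j. So we can:
1) collect the values per diagonal key,
2) sort each list once according to its triangle's rule,
3) write the values back in row-major order. Row-major traversal visits the cells of every diagonal from top-left to bottom-right, so handing out each sorted list front to back puts the values in exactly the required order.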

Algorithm

1) Traverse matrix and bucket values by key d = i - j.
2) For each key:
  - if d >= 0, sort descending;
  - else sort ascending.
3) Traverse the matrix again in row-major order and assign to each cell the next unused value from its own diagonal's bucket.
4) Return matrix.

Complexity Analysis

Time: O(n^2 log n) in the worst case (dominated by sorting the diagonals).
Space: O(n^2) for diagonal buckets.

Reference Implementations (Java / Go / C++ / Python / JavaScript)

import java.util.*;

class Solution {
    /**
     * Sorts every diagonal (cells sharing the same {@code i - j}) of a square grid in place.
     *
     * <p>Diagonals with {@code i - j >= 0} (the main diagonal and the bottom-left triangle) are
     * made non-increasing from top-left to bottom-right; diagonals with {@code i - j < 0}
     * (top-right triangle) are made non-decreasing.
     *
     * @param grid n x n matrix; modified in place
     * @return the same {@code grid} instance after sorting
     */
    public int[][] sortMatrix(int[][] grid) {
        int n = grid.length;
        // Diagonal key d = i - j -> values on that diagonal.
        Map<Integer, List<Integer>> buckets = new HashMap<>();

        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                buckets.computeIfAbsent(i - j, k -> new ArrayList<>()).add(grid[i][j]);
            }
        }

        for (Map.Entry<Integer, List<Integer>> e : buckets.entrySet()) {
            List<Integer> list = e.getValue();
            if (e.getKey() >= 0) {
                list.sort(Collections.reverseOrder());
            } else {
                Collections.sort(list);
            }
        }

        // Row-major traversal visits each diagonal top-left to bottom-right, so a per-diagonal
        // cursor hands out the sorted values in exactly the required order.
        Map<Integer, Integer> next = new HashMap<>();
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                int d = i - j;
                int p = next.getOrDefault(d, 0);
                grid[i][j] = buckets.get(d).get(p);
                next.put(d, p + 1);
            }
        }
        return grid;
    }
}
import "sort"

// sortMatrix sorts every diagonal (cells with equal i-j) of the n x n grid in
// place and returns the same grid. Diagonals with i-j >= 0 (main diagonal and
// below) become non-increasing from top-left to bottom-right; diagonals above
// the main one become non-decreasing.
func sortMatrix(grid [][]int) [][]int {
    n := len(grid)
    groups := make(map[int][]int)

    for r := range grid {
        for c := 0; c < n; c++ {
            groups[r-c] = append(groups[r-c], grid[r][c])
        }
    }

    // Sorting in place is enough: each slice shares its backing array with the map entry.
    for key, vals := range groups {
        if key < 0 {
            sort.Ints(vals)
            continue
        }
        sort.Sort(sort.Reverse(sort.IntSlice(vals)))
    }

    // Row-major order walks each diagonal from its top-left end, so consuming
    // each sorted group front to back restores values in the required order.
    used := make(map[int]int, len(groups))
    for r := range grid {
        for c := 0; c < n; c++ {
            key := r - c
            grid[r][c] = groups[key][used[key]]
            used[key]++
        }
    }
    return grid
}
class Solution {
public:
    // Sorts each diagonal (cells with equal i - j) of the n x n grid in place and
    // returns the resulting matrix. Diagonals with i - j >= 0 (main diagonal and
    // below) become non-increasing from top-left to bottom-right; diagonals above
    // the main one become non-decreasing.
    vector<vector<int>> sortMatrix(vector<vector<int>>& grid) {
        const int n = static_cast<int>(grid.size());
        // Diagonal key i - j is stored at slot (i - j) + n - 1, covering keys -(n-1) .. n-1.
        const int slots = (n == 0) ? 0 : 2 * n - 1;
        vector<vector<int>> diag(slots);

        for (int r = 0; r < n; ++r) {
            for (int c = 0; c < n; ++c) {
                diag[r - c + n - 1].push_back(grid[r][c]);
            }
        }

        for (int s = 0; s < slots; ++s) {
            vector<int>& vals = diag[s];
            const bool lowerOrMain = s >= n - 1;  // equivalent to i - j >= 0
            if (lowerOrMain) {
                // Ascending over reverse iterators == descending in forward order.
                sort(vals.rbegin(), vals.rend());
            } else {
                sort(vals.begin(), vals.end());
            }
        }

        // Row-major order reaches each diagonal's cells top-left first, so a
        // per-slot cursor writes the sorted values back in the right order.
        vector<int> cursor(slots, 0);
        for (int r = 0; r < n; ++r) {
            for (int c = 0; c < n; ++c) {
                const int s = r - c + n - 1;
                grid[r][c] = diag[s][cursor[s]++];
            }
        }
        return grid;
    }
};
from collections import defaultdict

class Solution:
    def sortMatrix(self, grid: list[list[int]]) -> list[list[int]]:
        """Sort every i - j diagonal of the square grid in place.

        Diagonals with i - j >= 0 (the main diagonal and below) end up
        non-increasing from top-left to bottom-right; diagonals above the main
        one end up non-decreasing. The same grid object is returned.
        """
        size = len(grid)
        cells = [(r, c) for r in range(size) for c in range(size)]

        groups = defaultdict(list)
        for r, c in cells:
            groups[r - c].append(grid[r][c])

        # One iterator per diagonal; row-major order visits each diagonal
        # top-left first, so values come out exactly where they belong.
        feeders = {
            key: iter(sorted(vals, reverse=key >= 0))
            for key, vals in groups.items()
        }

        for r, c in cells:
            grid[r][c] = next(feeders[r - c])

        return grid
/**
 * Sorts every diagonal (cells with equal i - j) of an n x n grid in place and
 * returns the same grid. Diagonals with i - j >= 0 (main diagonal and below)
 * become non-increasing from top-left to bottom-right; diagonals above the main
 * one become non-decreasing.
 * @param {number[][]} grid
 * @return {number[][]}
 */
var sortMatrix = function(grid) {
  const n = grid.length;
  const diagonals = new Map(); // key i - j -> { values, next }

  for (let row = 0; row < n; row++) {
    for (let col = 0; col < n; col++) {
      const key = row - col;
      let entry = diagonals.get(key);
      if (entry === undefined) {
        entry = { values: [], next: 0 };
        diagonals.set(key, entry);
      }
      entry.values.push(grid[row][col]);
    }
  }

  diagonals.forEach((entry, key) => {
    const compare = key >= 0 ? (a, b) => b - a : (a, b) => a - b;
    entry.values.sort(compare);
  });

  // Row-major order visits each diagonal top-left first, so a per-diagonal
  // cursor writes the sorted values back in the required order.
  for (let row = 0; row < n; row++) {
    for (let col = 0; col < n; col++) {
      const entry = diagonals.get(row - col);
      grid[row][col] = entry.values[entry.next++];
    }
  }

  return grid;
};

中文

题目概述

给你一个 n x n 矩阵,需要按对角线(同一条对角线满足 i-j 相同)排序:
- 左下三角(含主对角线)按非递增排序;
- 右上三角按非递减排序。

核心思路

同一条对角线可以用 d = i - j 唯一标识。先把每条对角线的数收集起来,再按规则排序,最后写回矩阵即可。

算法步骤

1)遍历矩阵,按 d=i-j 分桶收集。
2)遍历每个桶:
  - d >= 0(左下含主对角线)降序排序;
  - d < 0(右上)升序排序。
3)再次按行优先遍历矩阵,从对应桶中依次取值填回。行优先遍历恰好按从左上到右下的顺序访问每条对角线,因此按顺序取值即可保证结果正确。
4)返回结果矩阵。

复杂度分析

时间复杂度:最坏 O(n^2 log n)(主要来自对角线排序)。
空间复杂度:O(n^2)(存储对角线分桶)。

多语言参考实现(Java / Go / C++ / Python / JavaScript)

import java.util.*;

class Solution {
    /**
     * Sorts every diagonal (cells sharing the same {@code i - j}) of a square grid in place.
     *
     * <p>Diagonals with {@code i - j >= 0} (the main diagonal and the bottom-left triangle) are
     * made non-increasing from top-left to bottom-right; diagonals with {@code i - j < 0}
     * (top-right triangle) are made non-decreasing.
     *
     * @param grid n x n matrix; modified in place
     * @return the same {@code grid} instance after sorting
     */
    public int[][] sortMatrix(int[][] grid) {
        int n = grid.length;
        // Diagonal key d = i - j -> values on that diagonal.
        Map<Integer, List<Integer>> buckets = new HashMap<>();

        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                buckets.computeIfAbsent(i - j, k -> new ArrayList<>()).add(grid[i][j]);
            }
        }

        for (Map.Entry<Integer, List<Integer>> e : buckets.entrySet()) {
            List<Integer> list = e.getValue();
            if (e.getKey() >= 0) {
                list.sort(Collections.reverseOrder());
            } else {
                Collections.sort(list);
            }
        }

        // Row-major traversal visits each diagonal top-left to bottom-right, so a per-diagonal
        // cursor hands out the sorted values in exactly the required order.
        Map<Integer, Integer> next = new HashMap<>();
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                int d = i - j;
                int p = next.getOrDefault(d, 0);
                grid[i][j] = buckets.get(d).get(p);
                next.put(d, p + 1);
            }
        }
        return grid;
    }
}
import "sort"

// sortMatrix sorts every diagonal (cells with equal i-j) of the n x n grid in
// place and returns the same grid. Diagonals with i-j >= 0 (main diagonal and
// below) become non-increasing from top-left to bottom-right; diagonals above
// the main one become non-decreasing.
func sortMatrix(grid [][]int) [][]int {
    n := len(grid)
    groups := make(map[int][]int)

    for r := range grid {
        for c := 0; c < n; c++ {
            groups[r-c] = append(groups[r-c], grid[r][c])
        }
    }

    // Sorting in place is enough: each slice shares its backing array with the map entry.
    for key, vals := range groups {
        if key < 0 {
            sort.Ints(vals)
            continue
        }
        sort.Sort(sort.Reverse(sort.IntSlice(vals)))
    }

    // Row-major order walks each diagonal from its top-left end, so consuming
    // each sorted group front to back restores values in the required order.
    used := make(map[int]int, len(groups))
    for r := range grid {
        for c := 0; c < n; c++ {
            key := r - c
            grid[r][c] = groups[key][used[key]]
            used[key]++
        }
    }
    return grid
}
class Solution {
public:
    // Sorts each diagonal (cells with equal i - j) of the n x n grid in place and
    // returns the resulting matrix. Diagonals with i - j >= 0 (main diagonal and
    // below) become non-increasing from top-left to bottom-right; diagonals above
    // the main one become non-decreasing.
    vector<vector<int>> sortMatrix(vector<vector<int>>& grid) {
        const int n = static_cast<int>(grid.size());
        // Diagonal key i - j is stored at slot (i - j) + n - 1, covering keys -(n-1) .. n-1.
        const int slots = (n == 0) ? 0 : 2 * n - 1;
        vector<vector<int>> diag(slots);

        for (int r = 0; r < n; ++r) {
            for (int c = 0; c < n; ++c) {
                diag[r - c + n - 1].push_back(grid[r][c]);
            }
        }

        for (int s = 0; s < slots; ++s) {
            vector<int>& vals = diag[s];
            const bool lowerOrMain = s >= n - 1;  // equivalent to i - j >= 0
            if (lowerOrMain) {
                // Ascending over reverse iterators == descending in forward order.
                sort(vals.rbegin(), vals.rend());
            } else {
                sort(vals.begin(), vals.end());
            }
        }

        // Row-major order reaches each diagonal's cells top-left first, so a
        // per-slot cursor writes the sorted values back in the right order.
        vector<int> cursor(slots, 0);
        for (int r = 0; r < n; ++r) {
            for (int c = 0; c < n; ++c) {
                const int s = r - c + n - 1;
                grid[r][c] = diag[s][cursor[s]++];
            }
        }
        return grid;
    }
};
from collections import defaultdict

class Solution:
    def sortMatrix(self, grid: list[list[int]]) -> list[list[int]]:
        """Sort every i - j diagonal of the square grid in place.

        Diagonals with i - j >= 0 (the main diagonal and below) end up
        non-increasing from top-left to bottom-right; diagonals above the main
        one end up non-decreasing. The same grid object is returned.
        """
        size = len(grid)
        cells = [(r, c) for r in range(size) for c in range(size)]

        groups = defaultdict(list)
        for r, c in cells:
            groups[r - c].append(grid[r][c])

        # One iterator per diagonal; row-major order visits each diagonal
        # top-left first, so values come out exactly where they belong.
        feeders = {
            key: iter(sorted(vals, reverse=key >= 0))
            for key, vals in groups.items()
        }

        for r, c in cells:
            grid[r][c] = next(feeders[r - c])

        return grid
/**
 * Sorts every diagonal (cells with equal i - j) of an n x n grid in place and
 * returns the same grid. Diagonals with i - j >= 0 (main diagonal and below)
 * become non-increasing from top-left to bottom-right; diagonals above the main
 * one become non-decreasing.
 * @param {number[][]} grid
 * @return {number[][]}
 */
var sortMatrix = function(grid) {
  const n = grid.length;
  const diagonals = new Map(); // key i - j -> { values, next }

  for (let row = 0; row < n; row++) {
    for (let col = 0; col < n; col++) {
      const key = row - col;
      let entry = diagonals.get(key);
      if (entry === undefined) {
        entry = { values: [], next: 0 };
        diagonals.set(key, entry);
      }
      entry.values.push(grid[row][col]);
    }
  }

  diagonals.forEach((entry, key) => {
    const compare = key >= 0 ? (a, b) => b - a : (a, b) => a - b;
    entry.values.sort(compare);
  });

  // Row-major order visits each diagonal top-left first, so a per-diagonal
  // cursor writes the sorted values back in the required order.
  for (let row = 0; row < n; row++) {
    for (let col = 0; col < n; col++) {
      const entry = diagonals.get(row - col);
      grid[row][col] = entry.values[entry.next++];
    }
  }

  return grid;
};

Comments