定义 dp(col, pos) 表示从 col 列开始匹配 target[:pos+1] 的方案数。那么答案就是 dp(0, 0)
target[:pos+1] 表示从索引 0 到索引 pos 的 target 切片
接下来我们考虑如何转移:
对于每一个 col, 我们可以选择匹配或者不匹配。
如果匹配, 那么需要满足 word[col] == target[pos]
将匹配和不匹配的方案数累加记为答案。
class Solution:
def numWays(self, words: List[str], target: str) -> int:
MOD = 10 ** 9 + 7
k = len(words[0])
cnt = [[0] * k for _ in range(26)]
for j in range(k):
for word in words:
cnt[ord(word[j]) - ord('a')][j] += 1
@cache
def dp(col, pos):
if len(target) - pos > len(words[0]) - col: return 0 # 剪枝
if pos == len(target): return 1
if col == len(words[0]): return 0
ans = dp(col+1, pos) # skip
for word in words: # pick one of the word[col]
if word[col] == target[pos]:
ans += dp(col+1, pos+1)
ans %= MOD
return ans % MOD
return dp(0, 0) % MOD
另外 m 为 words 长度, k 为 word 长度, n 为 target 长度。
那么复杂度为保底的 DP 复杂度 n * k,再乘以 dp 内部转移的复杂度为 m,因此复杂度为 $O(m * n * k)$,代入题目的数据范围, 可以达到 10 ** 9, 无法通过。
class Solution:
def numWays(self, words: List[str], target: str) -> int:
MOD = 10 ** 9 + 7
k = len(words[0])
cnt = [[0] * k for _ in range(26)]
for j in range(k):
for word in words:
cnt[ord(word[j]) - ord('a')][j] += 1
@cache
def dp(col, pos):
if len(target) - pos > len(words[0]) - col: return 0 # 剪枝
if pos == len(target): return 1
if col == len(words[0]): return 0
ans = dp(col+1, pos) # skip
ans += dp(col+1, pos+1) * cnt[ord(target[pos]) - ord('a')][col] # 根据上面的提示,我们可以这样优化
return ans % MOD
return dp(0, 0) % MOD