模拟退火学习笔记

· · 个人记录

模拟退火学习笔记

定义

即 Simulated Annealing,SA。

模拟退火的原理也和金属退火的原理近似:将热力学的理论套用到统计学上,将搜寻空间内每一点想像成空气内的分子;分子的能量,就是它本身的动能;而搜寻空间内的每一点,也像空气分子一样带有“能量”,以表示该点对命题的合适程度。演算法先以搜寻空间内一个任意点作起始:每一步先选择一个“邻居”,然后再计算从现有位置到达“邻居”的概率。

——百度百科

算法流程

将系统状态设置为一个较高的温度 t。对于每次操作结束后,将当前温度设置为 t\times \delta(0< \delta< 1)\delta 即降温速率。\delta 越接近 1,计算次数越多,答案越准确,但同时也会导致消耗的时间较多。我们再设置一个末温 t_0,表示算法降温到 t_0 时退出。

对于每次操作前,根据当前温度 t 随机生成一个在当前答案 ans 附近的值 x。当 t 越大,x 可能的值越多。具体地,x=ans+\text{randint}(-k,k)\times t。其中 k 根据题意设置,需要在定义域之内,可以稍大一些。

计算当前值 f(x)f(ans) 相比较,进行更新。

具体地(假设要求 f 函数的最小值):

  1. 在实际运用中 $\text{randdouble}(0,1)$ 的实现都是 $\dfrac{\text{rand}()}{\texttt{RAND\_MAX}}$,而这样通常会导致较大的进度误差,所以普遍是将 $\texttt{RAND\_MAX}$ 移到左边,变成 $\min(1, \text{exp}(\dfrac{f(ans)-f(x)}{t}))\times\texttt{RAND\_MAX}\ge \text{rand()}$。 这是模拟退火的核心思想,在温度较高的时候,以一定概率接受较差的结果。这样可以很有效的退出局部最优解达到全局最优解。

大致过程如下:

可以看出随着温度下降,答案将集中于最优解上。

而模拟退火算法是一种随机化算法,通常一次并不能求出准确的答案,所以基本上不会运算一次,实际运用中运行次数会很多。故每次退火的时间复杂度要很低以保证较高的准确率(实际上与运气有很大的关系)。

如果你一直无法通过可以这样解决:

  1. 调大初温 t,增大降温速率 \delta,减小末温 t_0
  2. 多跑几遍
  3. 写正解

例题

[JSOI2004] 平衡点 / 吊打XXX

题目描述

如图,有 n 个重物,每个重物系在一条足够长的绳子上。

每条绳子自上而下穿过桌面上的洞,然后系在一起。图中 x 处就是公共的绳结。假设绳子是完全弹性的(即不会造成能量损失),桌子足够高(重物不会垂到地上),且忽略所有的摩擦,求绳结 x 最终平衡于何处。

注意:桌面上的洞都比绳结 x 小得多,所以即使某个重物特别重,绳结 x 也不可能穿过桌面上的洞掉下来,最多是卡在某个洞口处。

输入格式

文件的第一行为一个正整数 n1\le n\le 1000),表示重物和洞的数目。

接下来的 n 行,每行是 3 个整数 x_i, y_i, w_i,分别表示第 i 个洞的坐标以及第 i 个重物的重量。(-10000\le x_i,y_i\le10000, 0<w_i\le1000

输出格式

你的程序必须输出两个浮点数(保留小数点后三位),分别表示处于最终平衡状态时绳结 x 的横坐标和纵坐标。两个数以一个空格隔开。

样例输入

3
0 0 1
0 2 1
1 1 1

样例输出

0.577 1.000

题解

对于这道题而言,当平衡状态时,系统的重力势能即 \sum^n_{i=1} \text{dis}[i]\times \text{w}[i] 要等于 0。而用模拟退火很难实现等于 0,故要使其的值最小。

那么对于该算法而言,f(x)=\sum^n_{i=1} \text{dis}(x, i)\times \text{w}[i]。模拟实现即可。

要注意的几个点:

  1. 初温要尽可能大,我定义的是 5000
  2. 末温不要定义为 0,因为这将无法结束程序,尽量定义为一个极小值,我定义的是 10^{-15}
  3. 多次执行算法
#include<iostream>
#include<climits>
#include<cstring>
#include<cstdio>
#include<iomanip>
#include<cmath>
using namespace std;
inline int read() {
    int x = 0, f = 1; char ch = getchar();
    while (ch < '0' || ch>'9') { if (ch == '-') f = -1; ch = getchar(); }
    while (ch >= '0' && ch <= '9') { x = x * 10 + ch - 48; ch = getchar(); }
    return x * f;
}

const int maxn = 1e3 + 10;

int n;
struct point {
    double x, y, w;
} p[maxn];

double pow2(double x) { return x * x; }
double dis(point a, point b) {
    return sqrt(pow2(a.x - b.x) + pow2(a.y - b.y));
}
double f(point a) {
    double res = 0;
    for (int i = 1; i <= n; i++) {
        res += (dis(p[i], a) * p[i].w);
    }
    return res;
}

point ans;

const int range = 20000;

void SA() {
    double t = 5000.0;
    double t0 = 1e-15;
    double delta = 0.996;
    while (t > t0) {
        point now;
        now.x = ans.x + (RAND_MAX - rand() - rand()) * t;
        now.y = ans.y + (RAND_MAX - rand() - rand()) * t;
        now.w = f(now);
        if (now.w < ans.w) {
            ans = now;
        }
        else if (min(1.0, exp((ans.w - now.w) / t)) * RAND_MAX > rand()) {
            ans = now;
        }
        t *= delta;
    }
}

int main() {
    n = read();
    for (int i = 1; i <= n; i++) {
        cin >> p[i].x >> p[i].y >> p[i].w;
        ans.x += p[i].x;
        ans.y += p[i].y;
    }
    ans.x /= n;
    ans.y /= n;
    ans.w = f(ans);
    int times = 10;
    while (times--) {
        SA();
    }
    cout << fixed << setprecision(3) << ans.x << ' ' << ans.y << endl;
    return 0;
}