fmin.js 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425
  1. (function (global, factory) {
  2. typeof exports === 'object' && typeof module !== 'undefined' ? factory(exports) :
  3. typeof define === 'function' && define.amd ? define(['exports'], factory) :
  4. (factory((global.fmin = global.fmin || {})));
  5. }(this, function (exports) { 'use strict';
  /**
   * Finds a zero of a function by bisection, given two starting points whose
   * function values must not have the same sign.
   *
   * @param {function(number): number} f - function whose root is sought
   * @param {number} a - one end of the starting bracket
   * @param {number} b - other end of the starting bracket
   * @param {Object} [parameters]
   * @param {number} [parameters.maxIterations=100] - maximum number of halvings
   * @param {number} [parameters.tolerance=1e-10] - stop once the half-width of
   *     the bracket is smaller than this
   * @returns {number} approximate location of the root
   * @throws {string} if f(a) and f(b) are both positive or both negative
   */
  function bisect(f, a, b, parameters) {
    parameters = parameters || {};
    var maxIterations = parameters.maxIterations || 100,
      tolerance = parameters.tolerance || 1e-10,
      fA = f(a),
      fB = f(b),
      delta = b - a;

    // Signs are compared directly instead of testing the product of the two
    // values: a product of two tiny values underflows to +/-0, which hides
    // same-sign end points and (because -0 >= 0) sends the search the wrong way.
    if ((fA > 0 && fB > 0) || (fA < 0 && fB < 0)) {
      // kept as a plain string so existing callers that catch it keep working
      throw "Initial bisect points must have opposite signs";
    }

    if (fA === 0) return a;
    if (fB === 0) return b;

    for (var i = 0; i < maxIterations; ++i) {
      delta /= 2;
      var mid = a + delta,
        fMid = f(mid);

      // the root lies in [mid, mid + delta] when f(mid) has the sign of f(a)
      if ((fMid > 0 && fA > 0) || (fMid < 0 && fA < 0)) {
        a = mid;
      }

      if ((Math.abs(delta) < tolerance) || (fMid === 0)) {
        return mid;
      }
    }
    return a + delta;
  }
  // These basic vector operations are defined here rather than pulled in
  // from a dependency.
  /** Returns a new array holding `x` zeros. */
  function zeros(x) {
    var filled = new Array(x);
    var idx = 0;
    while (idx < x) {
      filled[idx] = 0;
      idx += 1;
    }
    return filled;
  }
  /** Returns `x` independent rows, each a new array of `y` zeros. */
  function zerosM(x, y) {
    var rows = zeros(x);
    for (var r = 0; r < rows.length; ++r) {
      rows[r] = zeros(y);
    }
    return rows;
  }
  /**
   * Dot product of two vectors; iterates over the length of `a` and
   * accumulates from index 0 upwards.
   */
  function dot(a, b) {
    var total = 0,
      count = a.length,
      k = 0;
    while (k < count) {
      total += a[k] * b[k];
      k += 1;
    }
    return total;
  }
  /** Euclidean length of a vector: the square root of its dot product with itself. */
  function norm2(a) {
    var sumOfSquares = dot(a, a);
    return Math.sqrt(sumOfSquares);
  }
  /**
   * Writes `value` multiplied by the scalar `c` into `ret`, element by element.
   * `ret` may be the same array as `value`.
   */
  function scale(ret, value, c) {
    var count = value.length;
    for (var k = 0; k < count; k += 1) {
      ret[k] = c * value[k];
    }
  }
  /**
   * Stores w1 * v1 + w2 * v2 in `ret`, one element at a time, for every index
   * of `ret`. `ret` may be the same array as `v1` or `v2`.
   */
  function weightedSum(ret, w1, v1, w2, v2) {
    var count = ret.length,
      k = 0;
    while (k < count) {
      ret[k] = w1 * v1[k] + w2 * v2[k];
      k += 1;
    }
  }
  /**
   * Minimizes a function using the downhill simplex (Nelder-Mead) method.
   *
   * @param {function(Array<number>): number} f - objective; called with a point
   * @param {Array<number>} x0 - starting point; it is copied, never modified
   * @param {Object} [parameters]
   * @param {number} [parameters.maxIterations] - defaults to x0.length * 200
   * @param {number} [parameters.nonZeroDelta=1.05] - factor applied to a
   *     non-zero coordinate of x0 to build an initial simplex vertex
   * @param {number} [parameters.zeroDelta=0.001] - value used for that vertex
   *     when the coordinate of x0 is zero
   * @param {number} [parameters.minErrorDelta=1e-6] - stopping requires the
   *     best and worst losses to differ by less than this ...
   * @param {number} [parameters.minTolerance=1e-5] - ... and every coordinate
   *     of the two best vertices to differ by less than this
   * @param {number} [parameters.rho=1] - reflection coefficient
   * @param {number} [parameters.chi=2] - expansion coefficient
   * @param {number} [parameters.psi=-0.5] - contraction coefficient
   * @param {number} [parameters.sigma=0.5] - reduction coefficient
   * @param {Array} [parameters.history] - if given, receives one
   *     {x, fx, simplex} snapshot per iteration
   * @returns {{fx: number, x: Array<number>}} the best vertex found; `x` also
   *     carries `fx` and `id` properties
   */
  function nelderMead(f, x0, parameters) {
    parameters = parameters || {};

    var maxIterations = parameters.maxIterations || x0.length * 200,
      nonZeroDelta = parameters.nonZeroDelta || 1.05,
      zeroDelta = parameters.zeroDelta || 0.001,
      minErrorDelta = parameters.minErrorDelta || 1e-6,
      // read from its own option (this used to read minErrorDelta by mistake)
      minTolerance = parameters.minTolerance || 1e-5,
      rho = (parameters.rho !== undefined) ? parameters.rho : 1,
      chi = (parameters.chi !== undefined) ? parameters.chi : 2,
      psi = (parameters.psi !== undefined) ? parameters.psi : -0.5,
      sigma = (parameters.sigma !== undefined) ? parameters.sigma : 0.5,
      maxDiff;

    // initialize simplex.
    var N = x0.length,
      simplex = new Array(N + 1);

    // Vertices are tagged with fx/id and overwritten in place later on, so the
    // first vertex is a copy of x0 rather than the caller's own array.
    simplex[0] = x0.slice();
    simplex[0].fx = f(simplex[0]);
    simplex[0].id = 0;
    for (var i = 0; i < N; ++i) {
      var point = x0.slice();
      point[i] = point[i] ? point[i] * nonZeroDelta : zeroDelta;
      simplex[i + 1] = point;
      simplex[i + 1].fx = f(point);
      simplex[i + 1].id = i + 1;
    }

    // overwrites the worst vertex (last after sorting) with `value`
    function updateSimplex(value) {
      for (var i = 0; i < value.length; i++) {
        simplex[N][i] = value[i];
      }
      simplex[N].fx = value.fx;
    }

    var sortOrder = function (a, b) { return a.fx - b.fx; };

    // scratch vectors, allocated once and reused by every iteration
    var centroid = x0.slice(),
      reflected = x0.slice(),
      contracted = x0.slice(),
      expanded = x0.slice();

    for (var iteration = 0; iteration < maxIterations; ++iteration) {
      simplex.sort(sortOrder);

      if (parameters.history) {
        // copy the simplex (since later iterations will mutate) and
        // sort it to have a consistent order between iterations
        var sortedSimplex = simplex.map(function (x) {
          var state = x.slice();
          state.fx = x.fx;
          state.id = x.id;
          return state;
        });
        sortedSimplex.sort(function (a, b) { return a.id - b.id; });

        parameters.history.push({x: simplex[0].slice(),
                                 fx: simplex[0].fx,
                                 simplex: sortedSimplex});
      }

      // largest coordinate gap between the best and second-best vertices
      maxDiff = 0;
      for (i = 0; i < N; ++i) {
        maxDiff = Math.max(maxDiff, Math.abs(simplex[0][i] - simplex[1][i]));
      }

      if ((Math.abs(simplex[0].fx - simplex[N].fx) < minErrorDelta) &&
          (maxDiff < minTolerance)) {
        break;
      }

      // compute the centroid of all but the worst point in the simplex
      for (i = 0; i < N; ++i) {
        centroid[i] = 0;
        for (var j = 0; j < N; ++j) {
          centroid[i] += simplex[j][i];
        }
        centroid[i] /= N;
      }

      // reflect the worst point past the centroid and compute loss at reflected
      // point
      var worst = simplex[N];
      weightedSum(reflected, 1 + rho, centroid, -rho, worst);
      reflected.fx = f(reflected);

      // if the reflected point is the best seen, then possibly expand
      if (reflected.fx < simplex[0].fx) {
        weightedSum(expanded, 1 + chi, centroid, -chi, worst);
        expanded.fx = f(expanded);
        if (expanded.fx < reflected.fx) {
          updateSimplex(expanded);
        } else {
          updateSimplex(reflected);
        }
      }

      // if the reflected point is worse than the second worst, we need to
      // contract
      else if (reflected.fx >= simplex[N - 1].fx) {
        var shouldReduce = false;

        if (reflected.fx > worst.fx) {
          // do an inside contraction
          weightedSum(contracted, 1 + psi, centroid, -psi, worst);
          contracted.fx = f(contracted);
          if (contracted.fx < worst.fx) {
            updateSimplex(contracted);
          } else {
            shouldReduce = true;
          }
        } else {
          // do an outside contraction
          weightedSum(contracted, 1 - psi * rho, centroid, psi * rho, worst);
          contracted.fx = f(contracted);
          if (contracted.fx < reflected.fx) {
            updateSimplex(contracted);
          } else {
            shouldReduce = true;
          }
        }

        if (shouldReduce) {
          // if we don't contract here, we're done
          if (sigma >= 1) break;

          // do a reduction: pull every other vertex towards the best one
          for (i = 1; i < simplex.length; ++i) {
            weightedSum(simplex[i], 1 - sigma, simplex[0], sigma, simplex[i]);
            simplex[i].fx = f(simplex[i]);
          }
        }
      } else {
        updateSimplex(reflected);
      }
    }

    simplex.sort(sortOrder);
    return {fx : simplex[0].fx,
            x : simplex[0]};
  }
  /// searches along line 'pk' for a point that satisfies the wolfe conditions
  /// See 'Numerical Optimization' by Nocedal and Wright p59-60
  /// f : objective function, called as f(x, fxprime); returns the loss at x and
  ///     is expected to write the gradient into fxprime
  /// pk : search direction
  /// current: object containing current position/gradient/loss (x, fxprime, fx)
  /// next: output: contains next position/gradient/loss, i.e. the last point
  ///       that was evaluated
  /// a : initial step size (defaults to 1)
  /// c1, c2 : constants of the two wolfe conditions (default 1e-6 and 0.1)
  /// returns a: step size taken, or 0 if zooming failed to find an acceptable
  ///            step
  function wolfeLineSearch(f, pk, current, next, a, c1, c2) {
    var phi0 = current.fx,
      phiPrime0 = dot(current.fxprime, pk),
      phi = phi0,
      phi_old = phi0,
      phiPrime = phiPrime0,
      a0 = 0;

    a = a || 1;
    c1 = c1 || 1e-6;
    c2 = c2 || 0.1;

    // moves `next` to current.x + step * pk and refreshes phi / phiPrime
    function evaluate(step) {
      weightedSum(next.x, 1.0, current.x, step, pk);
      phi = next.fx = f(next.x, next.fxprime);
      phiPrime = dot(next.fxprime, pk);
    }

    // narrows the bracket [a_lo, a_high] by repeated halving
    function zoom(a_lo, a_high, phi_lo) {
      for (var iteration = 0; iteration < 16; ++iteration) {
        a = (a_lo + a_high) / 2;
        evaluate(a);

        if ((phi > (phi0 + c1 * a * phiPrime0)) ||
            (phi >= phi_lo)) {
          a_high = a;
        } else {
          if (Math.abs(phiPrime) <= -c2 * phiPrime0) {
            return a;
          }

          if (phiPrime * (a_high - a_lo) >= 0) {
            a_high = a_lo;
          }

          a_lo = a;
          phi_lo = phi;
        }
      }

      return 0;
    }

    // bracketing phase: keep doubling the step until a bracket is found
    for (var iteration = 0; iteration < 10; ++iteration) {
      evaluate(a);

      if ((phi > (phi0 + c1 * a * phiPrime0)) ||
          (iteration && (phi >= phi_old))) {
        return zoom(a0, a, phi_old);
      }

      if (Math.abs(phiPrime) <= -c2 * phiPrime0) {
        return a;
      }

      if (phiPrime >= 0) {
        return zoom(a, a0, phi);
      }

      phi_old = phi;
      a0 = a;
      a *= 2;
    }

    // No bracket was found. `next` holds the point for a0, the last step that
    // was evaluated; `a` has already been doubled past it, so a0 is the step
    // actually taken.
    return a0;
  }
  /**
   * Minimizes `f` with the Polak-Ribiere conjugate gradient method, choosing
   * each step with wolfeLineSearch.
   *
   * f is called as f(x, fxprime) and its return value is stored as the loss.
   * params.maxIterations defaults to initial.length * 20; params.history, if
   * given, receives {x, fx, fxprime, alpha} snapshots. Returns the final
   * {x, fx, fxprime} state.
   */
  function conjugateGradient(f, initial, params) {
    // all buffers are allocated once here and reused by the loop below
    var options = params || {},
      iterationLimit = options.maxIterations || initial.length * 20,
      current = {x: initial.slice(), fx: 0, fxprime: initial.slice()},
      next = {x: initial.slice(), fx: 0, fxprime: initial.slice()},
      gradientChange = initial.slice(),
      direction,
      stepSize = 1;

    // appends the current state and the latest step size to the history
    function record() {
      if (!options.history) {
        return;
      }
      options.history.push({x: current.x.slice(),
                            fx: current.fx,
                            fxprime: current.fxprime.slice(),
                            alpha: stepSize});
    }

    current.fx = f(current.x, current.fxprime);

    // start along the negative gradient
    direction = current.fxprime.slice();
    scale(direction, current.fxprime, -1);

    for (var iter = 0; iter < iterationLimit; ++iter) {
      stepSize = wolfeLineSearch(f, direction, current, next, stepSize);

      // todo: history in wrong spot?
      record();

      if (stepSize) {
        // update direction using Polak–Ribiere CG method
        weightedSum(gradientChange, 1, next.fxprime, -1, current.fxprime);

        var previousNormSq = dot(current.fxprime, current.fxprime);
        var beta = Math.max(0, dot(gradientChange, next.fxprime) / previousNormSq);

        weightedSum(direction, beta, direction, -1, next.fxprime);

        // the new point becomes current; the old buffers are recycled
        var recycled = current;
        current = next;
        next = recycled;
      } else {
        // failed to find a point that satisfies the wolfe conditions.
        // reset direction for next iteration
        scale(direction, current.fxprime, -1);
      }

      if (norm2(current.fxprime) <= 1e-5) {
        break;
      }
    }

    record();

    return current;
  }
  /**
   * Minimizes `f` by gradient descent with a fixed learning rate.
   *
   * @param {function(Array<number>, Array<number>): number} f - called as
   *     f(x, fxprime); returns the loss at x and is expected to write the
   *     gradient into fxprime
   * @param {Array<number>} initial - starting point (copied, not modified)
   * @param {Object} [params]
   * @param {number} [params.maxIterations] - maximum number of steps; defaults
   *     to initial.length * 100
   * @param {number} [params.learnRate=0.001] - step size applied to the gradient
   * @param {Array} [params.history] - if given, receives an {x, fx, fxprime}
   *     snapshot of every point a step is considered from
   * @returns {{x: Array<number>, fx: number, fxprime: Array<number>}} the final
   *     point together with the loss and gradient evaluated at that point
   */
  function gradientDescent(f, initial, params) {
    params = params || {};
    var maxIterations = params.maxIterations || initial.length * 100,
      learnRate = params.learnRate || 0.001,
      current = {x: initial.slice(), fx: 0, fxprime: initial.slice()};

    // The point is evaluated right after every move (and once up front), so
    // fx/fxprime always describe the x that is returned. Stepping after the
    // evaluation used to hand back an x one step ahead of its fx/fxprime.
    current.fx = f(current.x, current.fxprime);

    for (var i = 0; i < maxIterations; ++i) {
      if (params.history) {
        params.history.push({x: current.x.slice(),
                             fx: current.fx,
                             fxprime: current.fxprime.slice()});
      }

      // converged: stop without moving away from the evaluated point
      if (norm2(current.fxprime) <= 1e-5) {
        break;
      }

      weightedSum(current.x, 1, current.x, -learnRate, current.fxprime);
      current.fx = f(current.x, current.fxprime);
    }

    return current;
  }
  /**
   * Minimizes `f` by steepest descent, with every step size chosen by
   * wolfeLineSearch.
   *
   * f is called as f(x, fxprime) and its return value is stored as the loss.
   * Options: maxIterations (default initial.length * 100), learnRate (initial
   * step, default 1), c1 / c2 (passed to the line search, default 1e-3 / 0.1)
   * and history. When history is given it receives, per iteration, the state
   * the step started from, the step size (as both learnRate and alpha) and
   * the points f was called with since the previous entry. Returns the final
   * {x, fx, fxprime} state.
   */
  function gradientDescentLineSearch(f, initial, params) {
    var options = params || {},
      current = {x: initial.slice(), fx: 0, fxprime: initial.slice()},
      next = {x: initial.slice(), fx: 0, fxprime: initial.slice()},
      iterationLimit = options.maxIterations || initial.length * 100,
      stepSize = options.learnRate || 1,
      direction = initial.slice(),
      c1 = options.c1 || 1e-3,
      c2 = options.c2 || 0.1,
      objective = f,
      sampledPoints = [];

    if (options.history) {
      // wrap the function call to track linesearch samples
      objective = function (x, fxprime) {
        sampledPoints.push(x.slice());
        return f(x, fxprime);
      };
    }

    current.fx = objective(current.x, current.fxprime);

    for (var iter = 0; iter < iterationLimit; ++iter) {
      // always search along the negative gradient
      scale(direction, current.fxprime, -1);
      stepSize = wolfeLineSearch(objective, direction, current, next, stepSize, c1, c2);

      if (options.history) {
        options.history.push({x: current.x.slice(),
                              fx: current.fx,
                              fxprime: current.fxprime.slice(),
                              functionCalls: sampledPoints,
                              learnRate: stepSize,
                              alpha: stepSize});
        // start a fresh sample list for the next entry
        sampledPoints = [];
      }

      // the searched point becomes current; the old buffers are recycled
      var recycled = current;
      current = next;
      next = recycled;

      if ((stepSize === 0) || (norm2(current.fxprime) < 1e-5)) {
        break;
      }
    }

    return current;
  }
  345. exports.bisect = bisect;
  346. exports.nelderMead = nelderMead;
  347. exports.conjugateGradient = conjugateGradient;
  348. exports.gradientDescent = gradientDescent;
  349. exports.gradientDescentLineSearch = gradientDescentLineSearch;
  350. exports.zeros = zeros;
  351. exports.zerosM = zerosM;
  352. exports.norm2 = norm2;
  353. exports.weightedSum = weightedSum;
  354. exports.scale = scale;
  355. }));