gradientDescent.js 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. import {dot, norm2, scale, zeros, weightedSum} from "./blas1";
  2. import {wolfeLineSearch} from "./linesearch";
  /**
   * Minimizes `f` with fixed-step gradient descent.
   *
   * @param {function(Array<number>, Array<number>): number} f - objective,
   *   called as `f(x, fxprime)`: it returns the value at `x` and is expected to
   *   write the gradient at `x` into `fxprime`.
   * @param {Array<number>} initial - starting point; copied, never modified.
   * @param {Object} [params] - optional settings (falsy values select defaults).
   * @param {number} [params.maxIterations] - maximum number of steps
   *   (default `initial.length * 100`).
   * @param {number} [params.learnRate] - step size applied along the negative
   *   gradient (default 0.001).
   * @param {Array} [params.history] - when provided, a snapshot
   *   `{x, fx, fxprime}` of each iterate is pushed before that iterate is
   *   tested for convergence.
   * @returns {{x: Array<number>, fx: number, fxprime: Array<number>}} the last
   *   iterate; `fx` and `fxprime` are always evaluated at the returned `x`.
   */
  export function gradientDescent(f, initial, params) {
    const options = params || {};
    const maxIterations = options.maxIterations || initial.length * 100;
    const learnRate = options.learnRate || 0.001;
    // fxprime starts as a copy of `initial` only to get a buffer of the right length.
    const current = {x: initial.slice(), fx: 0, fxprime: initial.slice()};

    // Evaluate up front and again after every step so that the returned
    // fx/fxprime always describe the returned x, including when the
    // iteration limit is reached.
    current.fx = f(current.x, current.fxprime);
    for (let i = 0; i < maxIterations; ++i) {
      if (options.history) {
        options.history.push({
          x: current.x.slice(),
          fx: current.fx,
          fxprime: current.fxprime.slice(),
        });
      }
      // Test before stepping: once the gradient is this small the current
      // point is the answer, and stepping past it would leave fx stale.
      if (norm2(current.fxprime) <= 1e-5) {
        break;
      }
      // In place: x <- x - learnRate * gradient.
      weightedSum(current.x, 1, current.x, -learnRate, current.fxprime);
      current.fx = f(current.x, current.fxprime);
    }
    return current;
  }
  /**
   * Minimizes `f` by gradient descent, choosing each step length with
   * `wolfeLineSearch` along the negative gradient.
   *
   * @param {function(Array<number>, Array<number>): number} f - objective,
   *   called as `f(x, fxprime)`: it returns the value at `x` and is expected to
   *   write the gradient at `x` into `fxprime`.
   * @param {Array<number>} initial - starting point; copied, never modified.
   * @param {Object} [params] - optional settings (falsy values select defaults).
   * @param {number} [params.maxIterations] - maximum number of line searches
   *   (default `initial.length * 100`).
   * @param {number} [params.learnRate] - initial trial step (default 1); the
   *   step returned by each line search is the first trial step of the next.
   * @param {number} [params.c1] - forwarded to `wolfeLineSearch` (default 1e-3);
   *   presumably the sufficient-decrease constant — confirm in ./linesearch.
   * @param {number} [params.c2] - forwarded to `wolfeLineSearch` (default 0.1);
   *   presumably the curvature constant — confirm in ./linesearch.
   * @param {Array} [params.history] - when provided, one entry is pushed per
   *   line search. Each entry holds the starting point `{x, fx, fxprime}`, the
   *   returned step (`learnRate`, duplicated as `alpha`), and `functionCalls`:
   *   copies of every `x` at which `f` was evaluated since the previous entry.
   * @returns {{x: Array<number>, fx: number, fxprime: Array<number>}} the last
   *   accepted iterate. The loop stops when the line search returns a zero
   *   step, when the gradient norm drops below 1e-5, or after `maxIterations`.
   */
  export function gradientDescentLineSearch(f, initial, params) {
    const options = params || {};
    const maxIterations = options.maxIterations || initial.length * 100;
    const c1 = options.c1 || 1e-3;
    const c2 = options.c2 || 0.1;
    const pk = initial.slice(); // search-direction buffer, overwritten each iteration
    let learnRate = options.learnRate || 1;
    let current = {x: initial.slice(), fx: 0, fxprime: initial.slice()};
    let next = {x: initial.slice(), fx: 0, fxprime: initial.slice()};
    let functionCalls = [];
    let objective = f;

    if (options.history) {
      // Wrap the objective so every line-search sample is recorded.
      objective = function (x, fxprime) {
        functionCalls.push(x.slice());
        return f(x, fxprime);
      };
    }

    current.fx = objective(current.x, current.fxprime);
    for (let i = 0; i < maxIterations; ++i) {
      scale(pk, current.fxprime, -1);
      learnRate = wolfeLineSearch(objective, pk, current, next, learnRate, c1, c2);
      if (options.history) {
        options.history.push({
          x: current.x.slice(),
          fx: current.fx,
          fxprime: current.fxprime.slice(),
          functionCalls: functionCalls,
          learnRate: learnRate,
          alpha: learnRate,
        });
        functionCalls = [];
      }
      // A zero step means the line search accepted no point, so `next` does
      // not hold an accepted iterate; keep `current` as the result.
      if (learnRate === 0) {
        break;
      }
      // Reuse the two point buffers instead of allocating new ones.
      const accepted = next;
      next = current;
      current = accepted;
      if (norm2(current.fxprime) < 1e-5) {
        break;
      }
    }
    return current;
  }