| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273 |
- import {dot, norm2, scale, zeros, weightedSum} from "./blas1";
- import {wolfeLineSearch} from "./linesearch";
/**
 * Minimizes `f` using gradient descent with a fixed step size.
 *
 * @param {function(number[], number[]): number} f - Objective, called as
 *   `f(x, fxprime)`: it returns the objective value at `x` and is expected to
 *   write the gradient at `x` into `fxprime` in place.
 * @param {number[]} initial - Starting point. It is copied, never modified.
 * @param {Object} [params] - Optional settings.
 * @param {number} [params.maxIterations] - Iteration cap; defaults to
 *   `initial.length * 100` (a falsy value such as 0 also selects the default).
 * @param {number} [params.learnRate=0.001] - Fixed step multiplier applied to
 *   the gradient each iteration.
 * @param {Array} [params.history] - If supplied, one `{x, fx, fxprime}`
 *   snapshot is pushed per iteration, taken before that iteration's step.
 * @returns {{x: number[], fx: number, fxprime: number[]}} The final point,
 *   with `fx` and `fxprime` evaluated at that same `x`.
 */
export function gradientDescent(f, initial, params) {
    const opts = params || {};
    const maxIterations = opts.maxIterations || initial.length * 100;
    const learnRate = opts.learnRate || 0.001;
    const current = {x: initial.slice(), fx: 0, fxprime: initial.slice()};

    for (let i = 0; i < maxIterations; ++i) {
        current.fx = f(current.x, current.fxprime);
        if (opts.history) {
            opts.history.push({x: current.x.slice(),
                               fx: current.fx,
                               fxprime: current.fxprime.slice()});
        }

        // x <- x - learnRate * gradient
        weightedSum(current.x, 1, current.x, -learnRate, current.fxprime);
        if (norm2(current.fxprime) <= 1e-5) {
            break;
        }
    }

    // Every iteration steps x *after* evaluating f, so at this point fx and
    // fxprime still describe the previous point. Re-evaluate once so the
    // returned x, fx and fxprime are mutually consistent.
    current.fx = f(current.x, current.fxprime);
    return current;
}
/**
 * Minimizes `f` by steepest descent, choosing each step length with
 * `wolfeLineSearch` along the negative gradient.
 *
 * @param {function(number[], number[]): number} f - Objective, called as
 *   `f(x, fxprime)`: it returns the objective value at `x` and is expected to
 *   write the gradient at `x` into `fxprime` in place.
 * @param {number[]} initial - Starting point. It is copied, never modified.
 * @param {Object} [params] - Optional settings.
 * @param {number} [params.maxIterations] - Iteration cap; defaults to
 *   `initial.length * 100` (a falsy value also selects the default).
 * @param {number} [params.learnRate=1] - Initial step length handed to the
 *   line search. Each later search starts from the previous accepted step.
 * @param {number} [params.c1=1e-3] - Forwarded unchanged to `wolfeLineSearch`
 *   (presumably the sufficient-decrease constant).
 * @param {number} [params.c2=0.1] - Forwarded unchanged to `wolfeLineSearch`
 *   (presumably the curvature constant).
 * @param {Array} [params.history] - If supplied, one entry is pushed per
 *   iteration: `{x, fx, fxprime, functionCalls, learnRate, alpha}`.
 *   `functionCalls` lists every point at which `f` was evaluated since the
 *   previous entry (including line-search samples). `alpha` duplicates
 *   `learnRate`.
 * @returns {{x: number[], fx: number, fxprime: number[]}} The last accepted
 *   point, with its value and gradient.
 */
export function gradientDescentLineSearch(f, initial, params) {
    const opts = params || {};
    const maxIterations = opts.maxIterations || initial.length * 100;
    const c1 = opts.c1 || 1e-3;
    const c2 = opts.c2 || 0.1;
    const pk = initial.slice();
    let learnRate = opts.learnRate || 1;
    let current = {x: initial.slice(), fx: 0, fxprime: initial.slice()};
    let next = {x: initial.slice(), fx: 0, fxprime: initial.slice()};
    let functionCalls = [];

    // When recording history, wrap the objective so that line-search samples
    // are captured too. The caller's `f` itself is left untouched.
    let objective = f;
    if (opts.history) {
        objective = function (x, fxprime) {
            functionCalls.push(x.slice());
            return f(x, fxprime);
        };
    }

    current.fx = objective(current.x, current.fxprime);
    for (let i = 0; i < maxIterations; ++i) {
        // Search direction: steepest descent, pk = -gradient.
        scale(pk, current.fxprime, -1);
        learnRate = wolfeLineSearch(objective, pk, current, next, learnRate, c1, c2);

        if (opts.history) {
            opts.history.push({x: current.x.slice(),
                               fx: current.fx,
                               fxprime: current.fxprime.slice(),
                               functionCalls: functionCalls,
                               learnRate: learnRate,
                               alpha: learnRate});
            functionCalls = [];
        }

        // A zero step means the line search failed. `next` then holds a
        // trial point that was not accepted (possibly worse than `current`),
        // so stop without adopting it.
        if (learnRate === 0) {
            break;
        }

        // Accept the line-search result; reuse the old buffers for the next trial.
        [current, next] = [next, current];

        if (norm2(current.fxprime) < 1e-5) {
            break;
        }
    }

    return current;
}
|