001    /*
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    
018    package org.apache.commons.math3.optimization.fitting;
019    
020    import java.util.ArrayList;
021    import java.util.List;
022    
023    import org.apache.commons.math3.analysis.DifferentiableMultivariateVectorFunction;
024    import org.apache.commons.math3.analysis.MultivariateMatrixFunction;
025    import org.apache.commons.math3.analysis.ParametricUnivariateFunction;
026    import org.apache.commons.math3.analysis.differentiation.DerivativeStructure;
027    import org.apache.commons.math3.analysis.differentiation.MultivariateDifferentiableVectorFunction;
028    import org.apache.commons.math3.optimization.DifferentiableMultivariateVectorOptimizer;
029    import org.apache.commons.math3.optimization.MultivariateDifferentiableVectorOptimizer;
030    import org.apache.commons.math3.optimization.PointVectorValuePair;
031    
032    /** Fitter for parametric univariate real functions y = f(x).
033     * <br/>
034     * When a univariate real function y = f(x) does depend on some
035     * unknown parameters p<sub>0</sub>, p<sub>1</sub> ... p<sub>n-1</sub>,
036     * this class can be used to find these parameters. It does this
037     * by <em>fitting</em> the curve so it remains very close to a set of
038     * observed points (x<sub>0</sub>, y<sub>0</sub>), (x<sub>1</sub>,
039     * y<sub>1</sub>) ... (x<sub>k-1</sub>, y<sub>k-1</sub>). This fitting
040     * is done by finding the parameters values that minimizes the objective
041     * function &sum;(y<sub>i</sub>-f(x<sub>i</sub>))<sup>2</sup>. This is
042     * really a least squares problem.
043     *
044     * @param <T> Function to use for the fit.
045     *
046     * @version $Id: CurveFitter.java 1422230 2012-12-15 12:11:13Z erans $
047     * @deprecated As of 3.1 (to be removed in 4.0).
048     * @since 2.0
049     */
050    @Deprecated
051    public class CurveFitter<T extends ParametricUnivariateFunction> {
052    
053        /** Optimizer to use for the fitting.
054         * @deprecated as of 3.1 replaced by {@link #optimizer}
055         */
056        @Deprecated
057        private final DifferentiableMultivariateVectorOptimizer oldOptimizer;
058    
059        /** Optimizer to use for the fitting. */
060        private final MultivariateDifferentiableVectorOptimizer optimizer;
061    
062        /** Observed points. */
063        private final List<WeightedObservedPoint> observations;
064    
065        /** Simple constructor.
066         * @param optimizer optimizer to use for the fitting
067         * @deprecated as of 3.1 replaced by {@link #CurveFitter(MultivariateDifferentiableVectorOptimizer)}
068         */
069        public CurveFitter(final DifferentiableMultivariateVectorOptimizer optimizer) {
070            this.oldOptimizer = optimizer;
071            this.optimizer    = null;
072            observations      = new ArrayList<WeightedObservedPoint>();
073        }
074    
075        /** Simple constructor.
076         * @param optimizer optimizer to use for the fitting
077         * @since 3.1
078         */
079        public CurveFitter(final MultivariateDifferentiableVectorOptimizer optimizer) {
080            this.oldOptimizer = null;
081            this.optimizer    = optimizer;
082            observations      = new ArrayList<WeightedObservedPoint>();
083        }
084    
085        /** Add an observed (x,y) point to the sample with unit weight.
086         * <p>Calling this method is equivalent to call
087         * {@code addObservedPoint(1.0, x, y)}.</p>
088         * @param x abscissa of the point
089         * @param y observed value of the point at x, after fitting we should
090         * have f(x) as close as possible to this value
091         * @see #addObservedPoint(double, double, double)
092         * @see #addObservedPoint(WeightedObservedPoint)
093         * @see #getObservations()
094         */
095        public void addObservedPoint(double x, double y) {
096            addObservedPoint(1.0, x, y);
097        }
098    
099        /** Add an observed weighted (x,y) point to the sample.
100         * @param weight weight of the observed point in the fit
101         * @param x abscissa of the point
102         * @param y observed value of the point at x, after fitting we should
103         * have f(x) as close as possible to this value
104         * @see #addObservedPoint(double, double)
105         * @see #addObservedPoint(WeightedObservedPoint)
106         * @see #getObservations()
107         */
108        public void addObservedPoint(double weight, double x, double y) {
109            observations.add(new WeightedObservedPoint(weight, x, y));
110        }
111    
112        /** Add an observed weighted (x,y) point to the sample.
113         * @param observed observed point to add
114         * @see #addObservedPoint(double, double)
115         * @see #addObservedPoint(double, double, double)
116         * @see #getObservations()
117         */
118        public void addObservedPoint(WeightedObservedPoint observed) {
119            observations.add(observed);
120        }
121    
122        /** Get the observed points.
123         * @return observed points
124         * @see #addObservedPoint(double, double)
125         * @see #addObservedPoint(double, double, double)
126         * @see #addObservedPoint(WeightedObservedPoint)
127         */
128        public WeightedObservedPoint[] getObservations() {
129            return observations.toArray(new WeightedObservedPoint[observations.size()]);
130        }
131    
132        /**
133         * Remove all observations.
134         */
135        public void clearObservations() {
136            observations.clear();
137        }
138    
139        /**
140         * Fit a curve.
141         * This method compute the coefficients of the curve that best
142         * fit the sample of observed points previously given through calls
143         * to the {@link #addObservedPoint(WeightedObservedPoint)
144         * addObservedPoint} method.
145         *
146         * @param f parametric function to fit.
147         * @param initialGuess first guess of the function parameters.
148         * @return the fitted parameters.
149         * @throws org.apache.commons.math3.exception.DimensionMismatchException
150         * if the start point dimension is wrong.
151         */
152        public double[] fit(T f, final double[] initialGuess) {
153            return fit(Integer.MAX_VALUE, f, initialGuess);
154        }
155    
156        /**
157         * Fit a curve.
158         * This method compute the coefficients of the curve that best
159         * fit the sample of observed points previously given through calls
160         * to the {@link #addObservedPoint(WeightedObservedPoint)
161         * addObservedPoint} method.
162         *
163         * @param f parametric function to fit.
164         * @param initialGuess first guess of the function parameters.
165         * @param maxEval Maximum number of function evaluations.
166         * @return the fitted parameters.
167         * @throws org.apache.commons.math3.exception.TooManyEvaluationsException
168         * if the number of allowed evaluations is exceeded.
169         * @throws org.apache.commons.math3.exception.DimensionMismatchException
170         * if the start point dimension is wrong.
171         * @since 3.0
172         */
173        public double[] fit(int maxEval, T f,
174                            final double[] initialGuess) {
175            // prepare least squares problem
176            double[] target  = new double[observations.size()];
177            double[] weights = new double[observations.size()];
178            int i = 0;
179            for (WeightedObservedPoint point : observations) {
180                target[i]  = point.getY();
181                weights[i] = point.getWeight();
182                ++i;
183            }
184    
185            // perform the fit
186            final PointVectorValuePair optimum;
187            if (optimizer == null) {
188                // to be removed in 4.0
189                optimum = oldOptimizer.optimize(maxEval, new OldTheoreticalValuesFunction(f),
190                                                target, weights, initialGuess);
191            } else {
192                optimum = optimizer.optimize(maxEval, new TheoreticalValuesFunction(f),
193                                             target, weights, initialGuess);
194            }
195    
196            // extract the coefficients
197            return optimum.getPointRef();
198        }
199    
200        /** Vectorial function computing function theoretical values. */
201        @Deprecated
202        private class OldTheoreticalValuesFunction
203            implements DifferentiableMultivariateVectorFunction {
204            /** Function to fit. */
205            private final ParametricUnivariateFunction f;
206    
207            /** Simple constructor.
208             * @param f function to fit.
209             */
210            public OldTheoreticalValuesFunction(final ParametricUnivariateFunction f) {
211                this.f = f;
212            }
213    
214            /** {@inheritDoc} */
215            public MultivariateMatrixFunction jacobian() {
216                return new MultivariateMatrixFunction() {
217                    public double[][] value(double[] point) {
218                        final double[][] jacobian = new double[observations.size()][];
219    
220                        int i = 0;
221                        for (WeightedObservedPoint observed : observations) {
222                            jacobian[i++] = f.gradient(observed.getX(), point);
223                        }
224    
225                        return jacobian;
226                    }
227                };
228            }
229    
230            /** {@inheritDoc} */
231            public double[] value(double[] point) {
232                // compute the residuals
233                final double[] values = new double[observations.size()];
234                int i = 0;
235                for (WeightedObservedPoint observed : observations) {
236                    values[i++] = f.value(observed.getX(), point);
237                }
238    
239                return values;
240            }
241        }
242    
243        /** Vectorial function computing function theoretical values. */
244        private class TheoreticalValuesFunction implements MultivariateDifferentiableVectorFunction {
245    
246            /** Function to fit. */
247            private final ParametricUnivariateFunction f;
248    
249            /** Simple constructor.
250             * @param f function to fit.
251             */
252            public TheoreticalValuesFunction(final ParametricUnivariateFunction f) {
253                this.f = f;
254            }
255    
256            /** {@inheritDoc} */
257            public double[] value(double[] point) {
258                // compute the residuals
259                final double[] values = new double[observations.size()];
260                int i = 0;
261                for (WeightedObservedPoint observed : observations) {
262                    values[i++] = f.value(observed.getX(), point);
263                }
264    
265                return values;
266            }
267    
268            /** {@inheritDoc} */
269            public DerivativeStructure[] value(DerivativeStructure[] point) {
270    
271                // extract parameters
272                final double[] parameters = new double[point.length];
273                for (int k = 0; k < point.length; ++k) {
274                    parameters[k] = point[k].getValue();
275                }
276    
277                // compute the residuals
278                final DerivativeStructure[] values = new DerivativeStructure[observations.size()];
279                int i = 0;
280                for (WeightedObservedPoint observed : observations) {
281    
282                    // build the DerivativeStructure by adding first the value as a constant
283                    // and then adding derivatives
284                    DerivativeStructure vi = new DerivativeStructure(point.length, 1, f.value(observed.getX(), parameters));
285                    for (int k = 0; k < point.length; ++k) {
286                        vi = vi.add(new DerivativeStructure(point.length, 1, k, 0.0));
287                    }
288    
289                    values[i++] = vi;
290    
291                }
292    
293                return values;
294            }
295    
296        }
297    
298    }