最近在写一个荧光图像分析软件,需要自己拟合方程。一元回归线公式的算法参考了《Java数值方法》,拟合度R^2(决定系数)是自己写的,欢迎讨论。计算结果和Excel完全一致。

总共三个文件:

DataPoint.java

/**
 * A data point for interpolation and regression.
 */
public class DataPoint {
    /** the x value */
    public float x;
    /** the y value */
    public float y;

    /**
     * Constructor.
     * @param x the x value
     * @param y the y value
     */
    public DataPoint(float x, float y) {
        this.x = x;
        this.y = y;
    }
}

RegressionLine.java

import java.util.*;
import java.math.BigDecimal;
import java.math.RoundingMode;

/**
 * A least-squares regression line {@code y = a1*x + a0}, built incrementally from data points.
 *
 * <p>Besides the coefficients, the class keeps the running sums of the closed-form solution,
 * the x/y extremes of the data (useful for plotting), and a private copy of every point so
 * that the coefficient of determination R^2 can be computed from the actual residuals.
 */
public class RegressionLine {
    /** Copies of all points added since construction or the last reset, in insertion order. */
    private final List<DataPoint> points = new ArrayList<DataPoint>();

    /** sum of x */
    private double sumX;
    /** sum of y */
    private double sumY;
    /** sum of x*x (products formed in double) */
    private double sumXX;
    /** sum of x*y (products formed in double) */
    private double sumXY;
    /** sum of y*y (products formed in double) */
    private double sumYY;

    // Extremes of the data; only meaningful while at least one point is stored.
    private float minX = Float.POSITIVE_INFINITY;
    private float maxX = Float.NEGATIVE_INFINITY;
    private float minY = Float.POSITIVE_INFINITY;
    private float maxY = Float.NEGATIVE_INFINITY;

    /** line coefficient a0 (intercept), kept in double precision */
    private double intercept = Double.NaN;
    /** line coefficient a1 (slope), kept in double precision */
    private double slope = Double.NaN;
    /** true if intercept/slope reflect the current set of points */
    private boolean coefsValid;

    /**
     * Creates an empty regression line.
     */
    public RegressionLine() {
    }

    /**
     * Creates a regression line from an array of data points.
     * @param data the data points; each one is passed to {@link #addDataPoint(DataPoint)}
     */
    public RegressionLine(DataPoint[] data) {
        for (DataPoint p : data) {
            addDataPoint(p);
        }
    }

    /**
     * Return the current number of data points.
     * @return the count
     */
    public int getDataPointCount() {
        return points.size();
    }

    /**
     * Return the coefficient a0 (intercept).
     * @return the value of a0, or NaN if it is undefined (fewer than two points, or all x equal)
     */
    public float getA0() {
        validateCoefficients();
        return (float) intercept;
    }

    /**
     * Return the coefficient a1 (slope).
     * @return the value of a1, or NaN if it is undefined (fewer than two points, or all x equal)
     */
    public float getA1() {
        validateCoefficients();
        return (float) slope;
    }

    /**
     * Return the sum of the x values.
     * @return the sum
     */
    public double getSumX() {
        return sumX;
    }

    /**
     * Return the sum of the y values.
     * @return the sum
     */
    public double getSumY() {
        return sumY;
    }

    /**
     * Return the sum of the x*x values.
     * @return the sum
     */
    public double getSumXX() {
        return sumXX;
    }

    /**
     * Return the sum of the x*y values.
     * @return the sum
     */
    public double getSumXY() {
        return sumXY;
    }

    /**
     * Return the sum of the y*y values.
     * @return the sum
     */
    public double getSumYY() {
        return sumYY;
    }

    /**
     * Return the smallest x value, truncated to int.
     * @return the minimum x, or 0 if there are no points
     */
    public int getXMin() {
        return points.isEmpty() ? 0 : (int) minX;
    }

    /**
     * Return the largest x value, truncated to int.
     * @return the maximum x, or 0 if there are no points
     */
    public int getXMax() {
        return points.isEmpty() ? 0 : (int) maxX;
    }

    /**
     * Return the smallest y value, truncated to int.
     * @return the minimum y, or 0 if there are no points
     */
    public int getYMin() {
        return points.isEmpty() ? 0 : (int) minY;
    }

    /**
     * Return the largest y value, truncated to int.
     * @return the maximum y, or 0 if there are no points
     */
    public int getYMax() {
        return points.isEmpty() ? 0 : (int) maxY;
    }

    /**
     * Add a new data point: update the sums, the extremes and the stored point list.
     * @param dataPoint the new data point
     */
    public void addDataPoint(DataPoint dataPoint) {
        float px = dataPoint.x;
        float py = dataPoint.y;
        // Promote before multiplying: a float product such as x*x is already rounded
        // once it exceeds 2^24 (i.e. for |x| > 4096), before it reaches the double sum.
        double x = px;
        double y = py;
        sumX += x;
        sumY += y;
        sumXX += x * x;
        sumXY += x * y;
        sumYY += y * y;

        minX = Math.min(minX, px);
        maxX = Math.max(maxX, px);
        minY = Math.min(minY, py);
        maxY = Math.max(maxY, py);

        // Store a copy: DataPoint's fields are public and mutable, and the stored points
        // must stay consistent with the sums above.
        points.add(new DataPoint(px, py));
        coefsValid = false;
    }

    /**
     * Return the value of the regression line function at x.
     * @param x the value of x
     * @return the value of the function at x, or NaN if fewer than two points were added
     */
    public float at(int x) {
        return evaluate(x);
    }

    /**
     * Return the value of the regression line function at x.
     * @param x the value of x
     * @return the value of the function at x, or NaN if fewer than two points were added
     */
    public float at(float x) {
        return evaluate(x);
    }

    /** Evaluates a0 + a1*x in double precision and narrows the result to float. */
    private float evaluate(double x) {
        if (points.size() < 2) {
            return Float.NaN;
        }
        validateCoefficients();
        return (float) (intercept + slope * x);
    }

    /**
     * Reset: remove all data points and clear every sum and extreme.
     */
    public void reset() {
        points.clear();
        sumX = sumY = sumXX = sumXY = sumYY = 0;
        minX = minY = Float.POSITIVE_INFINITY;
        maxX = maxY = Float.NEGATIVE_INFINITY;
        coefsValid = false;
    }

    /**
     * Validate the coefficients: recompute slope a1 and intercept a0 of y = a1*x + a0
     * if points were added or removed since the last computation.
     */
    private void validateCoefficients() {
        if (coefsValid) {
            return;
        }
        int n = points.size();
        if (n >= 2) {
            double xBar = sumX / n;
            double yBar = sumY / n;
            // Centred sums avoid the cancellation in n*sumXX - sumX^2 when x is large.
            double sxx = 0;
            double sxy = 0;
            for (DataPoint p : points) {
                double dx = p.x - xBar;
                sxx += dx * dx;
                sxy += dx * (p.y - yBar);
            }
            if (sxx == 0) {
                // All x identical: the least-squares slope is undefined.
                slope = intercept = Double.NaN;
            } else {
                slope = sxy / sxx;
                intercept = yBar - slope * xBar;
            }
        } else {
            slope = intercept = Double.NaN;
        }
        coefsValid = true;
    }

    /**
     * Return the coefficient of determination R^2 = 1 - SSE/SST, rounded half-up to
     * 4 decimal places.
     *
     * <p>SSE is the sum of squared residuals (observed y minus fitted y) over all points;
     * SST is the sum of squared deviations of y from its mean. For a least-squares line
     * with intercept this equals the squared correlation of x and y.
     *
     * @return R^2, or NaN if there are fewer than two points, all x are equal, or all y
     *         are equal
     */
    public double getR() {
        int n = points.size();
        if (n < 2) {
            return Double.NaN;
        }
        validateCoefficients();
        double yBar = sumY / n;
        // Local accumulators, so repeated calls return the same value.
        double sse = 0;
        double sst = 0;
        for (DataPoint p : points) {
            double residual = p.y - (intercept + slope * p.x);
            double deviation = p.y - yBar;
            sse += residual * residual;
            sst += deviation * deviation;
        }
        if (sst == 0) {
            return Double.NaN;
        }
        return round(1 - sse / sst, 4);
    }

    /**
     * Round a double half-up to the given number of decimal places, using its decimal
     * string form so that binary representation error does not affect the result.
     * @param v the value; NaN and infinities are returned unchanged
     * @param scale the number of decimal places
     * @return the rounded value
     * @throws IllegalArgumentException if scale is negative
     */
    public double round(double v, int scale) {
        if (scale < 0) {
            throw new IllegalArgumentException(
                "The scale must be a positive integer or zero");
        }
        if (Double.isNaN(v) || Double.isInfinite(v)) {
            return v; // BigDecimal cannot represent these
        }
        return new BigDecimal(Double.toString(v))
            .setScale(scale, RoundingMode.HALF_UP)
            .doubleValue();
    }

    /**
     * Round a float half-up to the given number of decimal places, using its decimal
     * string form so that binary representation error does not affect the result.
     * @param v the value; NaN and infinities are returned unchanged
     * @param scale the number of decimal places
     * @return the rounded value
     * @throws IllegalArgumentException if scale is negative
     */
    public float round(float v, int scale) {
        if (scale < 0) {
            throw new IllegalArgumentException(
                "The scale must be a positive integer or zero");
        }
        if (Float.isNaN(v) || Float.isInfinite(v)) {
            return v; // BigDecimal cannot represent these
        }
        // Float.toString, not Double.toString: widening first would expose binary error
        // (1.0005f widens to 1.00049996..., which would round down instead of up).
        return new BigDecimal(Float.toString(v))
            .setScale(scale, RoundingMode.HALF_UP)
            .floatValue();
    }
}
演示程序: LinearRegression.java /**  * <p><b>Linear Regression</b>  * <br>   * Demonstrate linear regression by constructing the regression line for a set  * of data points.  *   * <p>require DataPoint.java,RegressionLine.java   *   * <p>为了计算对于给定数据点的最小方差回线,需要计算SumX,SumY,SumXX,SumXY; (注:SumXX = Sum (X^2))  * <p><b>回归直线方程如下: f(x)=a1x+a0   </b>  * <p><b>斜率和截距的计算公式如下:</b>  * <br>n: 数据点个数  * <p>a1=(n(SumXY)-SumX*SumY)/(n*SumXX-(SumX)^2)  * <br>a0=(SumY - SumY * a1)/n   * <br>(也可表达为a0=averageY-a1*averageX)  *   * <p><b>画线的原理:两点成一直线,只要能确定两个点即可</b><br>  *  第一点:(0,a0) 再随意取一个x1值代入方程,取得y1,连结(0,a0)和(x1,y1)两点即可。  * 为了让线穿过整个图,x1可以取横坐标的最大值Xmax,即两点为(0,a0),(Xmax,Y)。如果y=a1*Xmax+a0,y大于  * 纵坐标最大值Ymax,则不用这个点。改用y取最大值Ymax,算得此时x的值,使用(X,Ymax), 即两点为(0,a0),(X,Ymax)  *   * <p><b>拟合度计算:(即Excel中的R^2)</b>  * <p> *R2 = 1 - E  * <p>误差E的计算:E = SSE/SST  * <p>SSE=sum((Yi-Y)^2) SST=sumYY - (sumY*sumY)/n;  * <p>   */ public class LinearRegression {     private static final int MAX_POINTS = 10;     private double E;     /**   * Main program.   *    * @param args   *            the array of runtime arguments   */     public static void main(String args[])     {         RegressionLine line = new RegressionLine();         line.addDataPoint(new DataPoint(20, 136));         line.addDataPoint(new DataPoint(40, 143));         line.addDataPoint(new DataPoint(60, 152));         line.addDataPoint(new DataPoint(80, 162));         line.addDataPoint(new DataPoint(100, 167));                  printSums(line);         printLine(line);     }     /**   * Print the computed sums.   
*    * @param line   *            the regression line   */     private static void printSums(RegressionLine line)     {         System.out.println("\n数据点个数 n = " + line.getDataPointCount());         System.out.println("\nSum x  = " + line.getSumX());         System.out.println("Sum y  = " + line.getSumY());         System.out.println("Sum xx = " + line.getSumXX());         System.out.println("Sum xy = " + line.getSumXY());         System.out.println("Sum yy = " + line.getSumYY());                     }     /**   * Print the regression line function.   *    * @param line   *            the regression line   */     private static void printLine(RegressionLine line)     {         System.out.println("\n回归线公式:  y = " +                            line.getA1() +                            "x + " + line.getA0());         System.out.println("拟合度:     R^2 = " + line.getR());     }       }  
 
  |