画像同士の距離が取れるなら、ビール画像の判定プログラムもできるよね。
ってことで、手持ちの画像をSVMに食わせてみて、ビール画像を判定させてみました。
誤判別?なにそれ
ソースはこんな感じ。
ビールの画像は、フォルダ名かファイル名に「beer」(小文字)を含めてください。それ以外の画像はビールではない画像として学習されます。
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.*;
import java.util.*;
import java.util.List;
import javax.imageio.ImageIO;
import javax.swing.ImageIcon;
import javax.swing.text.*;

/**
 * Small Swing tool that learns to tell "beer" images from other images.
 *
 * <p>On construction it walks {@code C:\search}, computes a Color Coherence
 * Vector (CCV) for every readable image, labels an image +1 when its canonical
 * path contains {@code "beer"} and -1 otherwise, estimates the Gaussian kernel
 * width from the data, and trains a soft-margin SVM with an SMO-style solver.
 * The window then lets the user type an image path and shows the verdict.
 *
 * <p>NOTE(review): all training runs inside the constructor, which {@link #main}
 * invokes on the event dispatch thread, so the window only appears once
 * training has finished.
 */
public class ColorCoherenceVecotrLaern extends javax.swing.JFrame {

    /** Edge length, in pixels, of the longer side of the result thumbnail. */
    private static final int THUMBNAIL_SIZE = 80;

    /** Gaussian kernel width; replaced by {@link #solveGaussianParam()} when data allows. */
    double sig = 35;

    /** Signed multipliers (lambda * label), one per training pattern. */
    double[] w;
    /** Bias term of the decision function. */
    double b;
    /** Box constraint C; presumably an infinite value would give a hard margin. */
    final double c = 10;
    /** Tolerance used for the KKT check (1 - epsilon). */
    final double tol = 0.7;
    /** Lagrange multipliers, one per training pattern. */
    double[] lambda;
    /** Total absolute multiplier change accumulated during the current sweep. */
    double z = 0;
    /** Training data: label (+1 beer / -1 other) paired with the image's CCV. */
    List<Map.Entry<Integer, int[]>> patterns = new ArrayList<Map.Entry<Integer, int[]>>();

    /** Cache of error values E_i; {@code NaN} marks an entry that must be recomputed. */
    private double[] eCache;

    /** Creates new form ColorCoherenceVecotrLaern */
    public ColorCoherenceVecotrLaern() {
        initComponents();
        System.out.println("CCV計算");
        readFile(new File("C:\\search"));
        System.out.println("カーネルパラメータ");
        solveGaussianParam();
        System.out.println("SMO");
        learn();
        System.out.println("終了");
    }

    /** Builds the UI: a result pane in the centre and a path field plus button on top. */
    private void initComponents() {
        javax.swing.JScrollPane jScrollPane1 = new javax.swing.JScrollPane();
        tpResut = new javax.swing.JTextPane();
        javax.swing.JPanel jPanel1 = new javax.swing.JPanel();
        txtSearch = new javax.swing.JTextField();
        javax.swing.JButton btnJudge = new javax.swing.JButton();

        setDefaultCloseOperation(javax.swing.WindowConstants.EXIT_ON_CLOSE);

        jScrollPane1.setViewportView(tpResut);
        getContentPane().add(jScrollPane1, java.awt.BorderLayout.CENTER);

        txtSearch.setColumns(30);
        jPanel1.add(txtSearch);

        btnJudge.setText("判定");
        btnJudge.addActionListener(new java.awt.event.ActionListener() {
            @Override
            public void actionPerformed(java.awt.event.ActionEvent evt) {
                btnJudgeActionPerformed(evt);
            }
        });
        jPanel1.add(btnJudge);

        getContentPane().add(jPanel1, java.awt.BorderLayout.NORTH);

        pack();
    }

    /**
     * Classifies the image whose path is typed in the text field and appends a
     * thumbnail plus the verdict to the result pane.
     *
     * @param evt the button event (unused)
     */
    private void btnJudgeActionPerformed(java.awt.event.ActionEvent evt) {
        String filename = txtSearch.getText();
        File f = new File(filename);
        try {
            BufferedImage img = ImageIO.read(f);
            if (img == null) {
                // ImageIO.read returns null when no registered reader understands the file.
                System.out.println("画像として読み込めません: " + filename);
                return;
            }
            int[] ccv = colorCoherenceVector(img);
            int result = trial(ccv);

            // Fit the image into a THUMBNAIL_SIZE box, never letting a side collapse to 0.
            int wd = img.getWidth();
            int ht = img.getHeight();
            if (wd > ht) {
                ht = Math.max(1, THUMBNAIL_SIZE * ht / wd);
                wd = THUMBNAIL_SIZE;
            } else {
                wd = Math.max(1, THUMBNAIL_SIZE * wd / ht);
                ht = THUMBNAIL_SIZE;
            }
            BufferedImage thumb = new BufferedImage(wd, ht, BufferedImage.TYPE_INT_RGB);
            Graphics2D g = (Graphics2D) thumb.getGraphics();
            g.setRenderingHint(RenderingHints.KEY_ANTIALIASING, RenderingHints.VALUE_ANTIALIAS_ON);
            g.drawImage(img, 0, 0, wd, ht, this);
            g.dispose();

            StyledDocument sd = tpResut.getStyledDocument();
            SimpleAttributeSet sas = new SimpleAttributeSet();
            StyleConstants.setIcon(sas, new ImageIcon(thumb));
            sd.insertString(sd.getLength(), "画像", sas);
            sd.insertString(sd.getLength(), result > 0 ? "ビールだ!\n" : "なにそれ\n", null);
        } catch (BadLocationException e) {
            // Should not happen because we always insert at the document end; report it anyway.
            System.out.println(e.getMessage());
        } catch (IOException e) {
            System.out.println(e.getMessage());
        }
    }

    /**
     * Recursively computes the CCV of every image below {@code f} and stores it
     * in {@link #patterns}. Files that cannot be decoded as images are skipped.
     *
     * @param f a file or directory
     */
    private void readFile(File f) {
        if (f.isDirectory()) {
            File[] files = f.listFiles();
            if (files == null) {
                // listFiles() returns null when the directory cannot be read.
                return;
            }
            for (File file : files) {
                readFile(file);
            }
            return;
        }
        try {
            BufferedImage img = ImageIO.read(f);
            if (img == null) {
                return;
            }
            int[] data = colorCoherenceVector(img);
            // Label: +1 when the path mentions "beer" (case-sensitive), -1 otherwise.
            int label = f.getCanonicalPath().contains("beer") ? 1 : -1;
            patterns.add(new AbstractMap.SimpleEntry<Integer, int[]>(label, data));
        } catch (IOException ex) {
            System.out.println(ex.getMessage());
        }
    }

    /**
     * Computes the Color Coherence Vector of an image.
     *
     * <p>Steps: scale so the longer side is 200 px, smooth with a 3x3 Gaussian
     * filter, reduce each channel to its top two bits (64 colours), label
     * same-coloured neighbouring pixels in a single raster pass, and count,
     * per colour, how many regions have at least 20 pixels (coherent, index
     * {@code colour * 2}) and how many are smaller (incoherent, index
     * {@code colour * 2 + 1}).
     *
     * @param imgsrc the source image
     * @return an array of length 129; only indices 0..127 are ever written
     */
    public static int[] colorCoherenceVector(BufferedImage imgsrc) {
        int w = imgsrc.getWidth();
        int h = imgsrc.getHeight();

        // Size normalisation. Clamp to 1 so very thin images do not produce a 0-pixel side.
        int limit = 200;
        if (w < h) {
            w = Math.max(1, w * limit / h);
            h = limit;
        } else {
            h = Math.max(1, h * limit / w);
            w = limit;
        }
        BufferedImage img = new BufferedImage(w, h, BufferedImage.TYPE_INT_RGB);
        Graphics2D grp = (Graphics2D) img.getGraphics();
        grp.setRenderingHint(RenderingHints.KEY_ANTIALIASING, RenderingHints.VALUE_ANTIALIAS_ON);
        grp.drawImage(imgsrc, 0, 0, w, h, null);
        grp.dispose();

        // Gaussian filter; at the borders only the in-range taps are used and renormalised.
        int[] ctemp = img.getRGB(0, 0, w, h, null, 0, w);
        int[] ctbl = new int[ctemp.length];
        int[][] filter = {
            {1, 2, 1},
            {2, 4, 2},
            {1, 2, 1}};
        for (int y = 0; y < h; ++y) {
            for (int x = 0; x < w; ++x) {
                int tr = 0;
                int tg = 0;
                int tb = 0;
                int t = 0;
                for (int i = -1; i < 2; ++i) {
                    for (int j = -1; j < 2; ++j) {
                        if (y + i < 0 || x + j < 0 || y + i >= h || x + j >= w) {
                            continue;
                        }
                        int weight = filter[i + 1][j + 1];
                        int adr = (x + j) + (y + i) * w;
                        t += weight;
                        tr += weight * ((ctemp[adr] >> 16) & 255);
                        tg += weight * ((ctemp[adr] >> 8) & 255);
                        tb += weight * (ctemp[adr] & 255);
                    }
                }
                ctbl[x + y * w] = ((tr / t) << 16) + ((tg / t) << 8) + tb / t;
            }
        }

        // Colour reduction: keep the top two bits of each channel of the FILTERED pixel.
        // (Reading the unfiltered buffer here would silently discard the smoothing above.)
        for (int i = 0; i < ctbl.length; ++i) {
            int r = (ctbl[i] >> 16) & 192;
            int g = (ctbl[i] >> 8) & 192;
            int b = ctbl[i] & 192;
            ctbl[i] = (r << 16) + (g << 8) + b;
        }

        // Labelling: a pixel inherits the label of the first same-coloured neighbour among
        // upper-left, upper, upper-right and left; otherwise it starts a new region.
        int[][] lbl = new int[w][h];
        int id = 0;
        for (int y = 0; y < h; ++y) {
            for (int x = 0; x < w; ++x) {
                int col = ctbl[y * w + x];
                if (y > 0) {
                    if (x > 0 && ctbl[(y - 1) * w + x - 1] == col) {
                        // same as upper-left
                        lbl[x][y] = lbl[x - 1][y - 1];
                        continue;
                    }
                    if (ctbl[(y - 1) * w + x] == col) {
                        // same as upper
                        lbl[x][y] = lbl[x][y - 1];
                        continue;
                    }
                    if (x < w - 1 && ctbl[(y - 1) * w + x + 1] == col) {
                        // same as upper-right
                        lbl[x][y] = lbl[x + 1][y - 1];
                        continue;
                    }
                }
                if (x > 0 && ctbl[y * w + x - 1] == col) {
                    // same as left
                    lbl[x][y] = lbl[x - 1][y];
                    continue;
                }
                lbl[x][y] = id;
                ++id;
            }
        }

        // Aggregation: size and colour of every region.
        int[] count = new int[id];
        int[] color = new int[id];
        for (int x = 0; x < w; ++x) {
            for (int y = 0; y < h; ++y) {
                count[lbl[x][y]]++;
                color[lbl[x][y]] = ctbl[y * w + x];
            }
        }
        int[] data = new int[129];
        for (int i = 0; i < id; ++i) {
            int d = color[i];
            // Pack the three 2-bit channels into a 6-bit colour index (0..63).
            color[i] = (((d >> 22) & 3) << 4) + (((d >> 14) & 3) << 2) + ((d >> 6) & 3);
            if (count[i] < 20) {
                data[color[i] * 2 + 1]++;
            } else {
                data[color[i] * 2]++;
            }
        }
        return data;
    }

    /** Returns the squared Euclidean distance between two vectors of equal length. */
    private static double squaredDistance(int[] x1, int[] x2) {
        // Accumulate in double so long vectors with large counts cannot overflow an int.
        double total = 0;
        for (int i = 0; i < x1.length; ++i) {
            double diff = x1[i] - x2[i];
            total += diff * diff;
        }
        return total;
    }

    /**
     * Estimates the kernel width from the data: twice the root of the mean squared
     * distance from each pattern to its nearest non-identical pattern. When no
     * such pair exists (fewer than two distinct patterns) the current value of
     * {@link #sig} is kept.
     */
    void solveGaussianParam() {
        double s = 0;
        int count = 0;
        for (Map.Entry<Integer, int[]> p1 : patterns) {
            double m = Double.MAX_VALUE;
            for (Map.Entry<Integer, int[]> p2 : patterns) {
                double d = squaredDistance(p1.getValue(), p2.getValue());
                if (d == 0) {
                    // the pattern itself, or an exact duplicate
                    continue;
                }
                if (d < m) {
                    m = d;
                }
            }
            if (m == Double.MAX_VALUE) {
                continue;
            }
            s += m;
            ++count;
        }
        if (count > 0) {
            // Without this guard an empty or degenerate training set would turn sig into NaN.
            sig = Math.sqrt(s / count) * 2;
        }
        System.out.println(sig);
    }

    /**
     * Gaussian kernel {@code exp(-|x1 - x2|^2 / sig^2)}.
     *
     * @param x1 first vector
     * @param x2 second vector, same length as {@code x1}
     * @return the kernel value
     */
    double kernel(int[] x1, int[] x2) {
        return Math.exp(-squaredDistance(x1, x2) / (sig * sig));
    }

    /**
     * Trains the SVM on {@link #patterns}: finds the multipliers with an SMO-style
     * loop, then derives the signed weights {@link #w} and the bias {@link #b}.
     */
    public void learn() {
        w = new double[patterns.size()];
        b = 0;
        lambda = new double[patterns.size()]; // zero-initialised by the JVM

        // Find the Lagrange multipliers.
        boolean alldata = true;  // sweep every pattern (otherwise only 0 < alpha < C)
        boolean changed = false; // something changed during the sweep
        eCache = new double[patterns.size()];
        int lp;
        for (lp = 0; lp < 500000 && (alldata || changed); ++lp) {
            changed = false;
            z = 0;
            boolean lastchange = true;
            PROC_LOOP:
            for (int j = 0; j < patterns.size(); ++j) {
                // Choose reference point 2.
                double alpha2 = lambda[j];
                if (!alldata && (alpha2 <= 0 || alpha2 >= c)) {
                    // only handle points with 0 < alpha < C
                    continue;
                }
                if (lastchange) {
                    // First pass, or multipliers changed: the cached errors are stale.
                    for (int i = 0; i < eCache.length; ++i) {
                        eCache[i] = Double.NaN;
                    }
                }
                lastchange = false;
                int t2 = patterns.get(j).getKey();
                double fx2 = calcE(j);

                // KKT check: skip points that already satisfy the conditions.
                double r2 = fx2 * t2;
                if (!((alpha2 < c && r2 < -tol) || (alpha2 > 0 && r2 > tol))) {
                    continue;
                }

                // Choose reference point 1.
                // Heuristic 1: the non-bound point with the largest |E1 - E2|.
                int i = 0;
                int offset = (int) (Math.random() * patterns.size());
                double max = -1;
                for (int ll = 0; ll < patterns.size(); ++ll) {
                    int l = (ll + offset) % patterns.size();
                    // 0 < alpha < C
                    if (0 >= lambda[l] || c <= lambda[l]) {
                        continue;
                    }
                    double dif = Math.abs(calcE(l) - fx2);
                    if (dif > max) {
                        max = dif;
                        i = l;
                    }
                }
                if (max >= 0) {
                    if (step(i, j)) {
                        // progress made, move on to the next point
                        changed = true;
                        lastchange = true;
                        continue PROC_LOOP;
                    }
                }

                // Heuristic 2: any non-bound point, starting from a random position.
                offset = (int) (Math.random() * patterns.size());
                for (int l = 0; l < patterns.size(); ++l) {
                    // 0 < alpha < C
                    i = (l + offset) % patterns.size();
                    if (0 >= lambda[i] || c <= lambda[i]) {
                        continue;
                    }
                    if (step(i, j)) {
                        changed = true;
                        lastchange = true;
                        continue PROC_LOOP;
                    }
                }

                // Heuristic 3: any point at all, starting from a random position.
                offset = (int) (Math.random() * patterns.size());
                for (int l = 0; l < patterns.size(); ++l) {
                    i = (l + offset) % patterns.size();
                    if (step(i, j)) {
                        changed = true;
                        lastchange = true;
                        continue PROC_LOOP;
                    }
                }
            }
            // Treat a negligible total change as "nothing changed"; a full sweep
            // without changes ends the loop.
            if (z < 0.01) {
                changed = false;
            }
            if (alldata) {
                alldata = false;
            } else {
                if (changed) {
                    alldata = true;
                }
            }
        }
        //System.out.println(lp);

        // Signed weights.
        for (int i = 0; i < w.length; ++i) {
            w[i] = lambda[i] * patterns.get(i).getKey();
        }

        // Bias: average of y_i - sum_l w_l K(x_i, x_l) over the support vectors, so that
        // trial() yields y_i on them. The label term y_i is required by that identity;
        // leaving it out shifts the boundary whenever the support vectors are not
        // evenly split between the two classes.
        double biasSum = 0;
        int count = 0;
        for (int i = 0; i < lambda.length; ++i) {
            if (Math.abs(w[i]) <= 0.05) {
                continue;
            }
            double residual = patterns.get(i).getKey();
            for (int l = 0; l < w.length; ++l) {
                residual -= w[l] * kernel(
                        patterns.get(i).getValue(),
                        patterns.get(l).getValue());
            }
            biasSum += residual;
            ++count;
        }
        // With no support vector (e.g. no training data) fall back to 0 instead of NaN.
        b = count > 0 ? biasSum / count : 0;
    }

    /**
     * Classifies a CCV with the trained model.
     *
     * @param data the CCV to classify
     * @return 1 when the decision value is positive, otherwise -1
     */
    public int trial(int[] data) {
        double s = b;
        for (int i = 0; i < w.length; ++i) {
            Map.Entry<Integer, int[]> p = patterns.get(i);
            s += w[i] * kernel(data, p.getValue());
        }
        return s > 0 ? 1 : -1;
    }

    /**
     * Returns the error of pattern {@code i} under the current multipliers
     * (decision value minus label), using {@link #eCache} when the entry is valid.
     */
    private double calcE(int i) {
        if (!Double.isNaN(eCache[i])) {
            return eCache[i];
        }
        double e = b - patterns.get(i).getKey();
        for (int j = 0; j < lambda.length; ++j) {
            e += lambda[j] * patterns.get(j).getKey()
                    * kernel(patterns.get(j).getValue(), patterns.get(i).getValue());
        }
        eCache[i] = e;
        return e;
    }

    /**
     * Jointly optimises the multipliers of patterns {@code i} and {@code j}.
     *
     * @param i index of reference point 1
     * @param j index of reference point 2
     * @return {@code true} when the multipliers were updated; the caller is then
     *         responsible for invalidating {@link #eCache}
     */
    private boolean step(int i, int j) {
        if (i == j) {
            return false;
        }
        double fx2 = calcE(j);
        int t1 = patterns.get(i).getKey();
        int t2 = patterns.get(j).getKey();
        double fx1 = calcE(i);

        // Unclipped new value for reference point 2.
        double k11 = kernel(patterns.get(i).getValue(), patterns.get(i).getValue());
        double k22 = kernel(patterns.get(j).getValue(), patterns.get(j).getValue());
        double k12 = kernel(patterns.get(i).getValue(), patterns.get(j).getValue());
        double eta = k11 + k22 - 2 * k12;
        if (eta <= 0) {
            return false;
        }
        double newwj = lambda[j] + t2 * (fx1 - fx2) / eta;

        // Clipping to the feasible segment [u, v].
        double u;
        double v;
        if (t1 == t2) {
            u = Math.max(0, lambda[j] + lambda[i] - c);
            v = Math.min(c, lambda[j] + lambda[i]);
        } else {
            u = Math.max(0, lambda[j] - lambda[i]);
            v = Math.min(c, c + lambda[j] - lambda[i]);
        }
        if (u == v) {
            return false;
        }
        newwj = Math.max(u, newwj);
        newwj = Math.min(v, newwj);

        // Derive reference point 1 from reference point 2.
        z += Math.abs(lambda[j] - newwj);
        lambda[i] += t1 * t2 * (lambda[j] - newwj);
        lambda[j] = newwj;
        return true;
    }

    /**
     * @param args the command line arguments
     */
    public static void main(String args[]) {
        java.awt.EventQueue.invokeLater(new Runnable() {
            @Override
            public void run() {
                new ColorCoherenceVecotrLaern().setVisible(true);
            }
        });
    }

    // Variables declaration - do not modify
    private javax.swing.JTextPane tpResut;
    private javax.swing.JTextField txtSearch;
    // End of variables declaration
}