A Walkthrough of the MultiLayer Perceptron Module in Mahout's Source Code

A while back, while studying neural networks, I worked with BPNN. With an eye toward scaling the model out in a distributed setting, I decided to try Mahout's MultiLayer Perceptron (MLP). I downloaded and read through the module's source code, and this post records my study notes, partly so I don't forget the details later, and partly so we can all learn together.

The Mahout version used here is 0.10, simply because Apache seems to have removed the MLP module in Mahout 0.11 (at least I could not find it there).

Module path: mr.src.main.java.org.apache.mahout.classifier.mlp

This directory contains five .java files:

[Image: directory listing of the mlp package, showing five files:
MultilayerPerceptron.java, NeuralNetwork.java, NeuralNetworkFunctions.java,
RunMultilayerPerceptron.java, TrainMultilayerPerceptron.java]

Each file defines a class named after the file. MultilayerPerceptron is a subclass of NeuralNetwork, and the latter is the core of the whole NN module; NeuralNetworkFunctions defines the mathematical functions used throughout the module; the last two files wrap the module's training process (TrainMultilayerPerceptron) and its prediction process (RunMultilayerPerceptron). Here we mainly study the NeuralNetwork class and its implementation.

The NeuralNetwork class contains a number of member variables and methods. The main ones are listed below.

Main member variables of the Mahout neural network module and their getter/setter methods:

| Member variable      | Getter                    | Setter                    |
| -------------------- | ------------------------- | ------------------------- |
| learningRate         | getLearningRate()         | setLearningRate()         |
| momentumWeight       | getMomentumWeight()       | setMomentumWeight()       |
| regularizationWeight | getRegularizationWeight() | setRegularizationWeight() |
| trainingMethod       | getTrainingMethod()       | setTrainingMethod()       |
| costFunction         | getCostFunction()         | setCostFunction()         |

Main member methods of the Mahout neural network module:

- addLayer(int size, boolean isFinalLayer, String squashingFunctionName): adds a new layer to the network. size is the number of neurons in the layer, isFinalLayer marks whether this is the network's final layer, and squashingFunctionName names the layer's activation (squashing) function.
- trainOnline(Vector trainingInstance): trains the model online; the argument is the vector formed by the input features followed by the expected output features.
- getOutput(Vector instance): computes the model's output for the given input feature vector.
- setModelPath(String modelPath): sets the model path to modelPath.
- writeModelToFile(): writes the model to the previously specified modelPath.

The trainOnline() method implements the training process. Let's look inside it:

  public void trainOnline(Vector trainingInstance) {
    Matrix[] matrices = trainByInstance(trainingInstance);
    updateWeightMatrices(matrices);
  }
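This two-step contract — compute a set of weight updates from a single instance, then apply them immediately — is the essence of online (stochastic) training. The same pattern can be sketched self-containedly with a single linear neuron; everything below is illustrative, not Mahout code:

```java
// Illustrative only: one linear neuron trained online, mirroring the
// trainOnline = "compute update" + "apply update" split. Not Mahout code.
public class OnlineTrainSketch {
  static double weight = 0.0;             // the model's single parameter
  static final double LEARNING_RATE = 0.1;

  // Analogue of trainByInstance(): compute the update but do not apply it
  static double updateByInstance(double x, double label) {
    double prediction = weight * x;
    // Gradient step for squared error: -lr * (prediction - label) * x
    return -LEARNING_RATE * (prediction - label) * x;
  }

  // Analogue of trainOnline(): compute the update, then apply it
  static void trainOnline(double x, double label) {
    weight += updateByInstance(x, label);
  }

  public static void main(String[] args) {
    for (int step = 0; step < 100; step++) {
      trainOnline(1.0, 2.0);              // learn y = 2x from one example
    }
    System.out.println(weight);           // converges toward 2.0
  }
}
```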
That is, it first runs trainByInstance() and stores the result in matrices, then applies it with updateWeightMatrices(matrices). Next, trainByInstance():

 public Matrix[] trainByInstance(Vector trainingInstance) {
    // validate training instance
    int inputDimension = layerSizeList.get(0) - 1;
    int outputDimension = layerSizeList.get(this.layerSizeList.size() - 1);
    Preconditions.checkArgument(inputDimension + outputDimension == trainingInstance.size(),
        String.format("The dimension of training instance is %d, but requires %d.", trainingInstance.size(),
            inputDimension + outputDimension));


    if (trainingMethod.equals(TrainingMethod.GRADIENT_DESCENT)) {
      return trainByInstanceGradientDescent(trainingInstance);
    }
    throw new IllegalArgumentException("Training method is not supported.");
  }
In this method, the dimension of the trainingInstance argument equals the sum of the input feature dimension and the output feature dimension. On entry, the method uses the layerSizeList field to look up the node counts of the network's first (input) layer and last (output) layer; the "- 1" accounts for the input layer's extra bias node. checkArgument() then verifies that the input and output dimensions together match the size of the supplied instance. After this validation, the sample is trained with the GRADIENT_DESCENT method (currently the only training method the module supports), delegating to trainByInstanceGradientDescent():

private Matrix[] trainByInstanceGradientDescent(Vector trainingInstance) {
    int inputDimension = layerSizeList.get(0) - 1;

    Vector inputInstance = new DenseVector(layerSizeList.get(0));
    inputInstance.set(0, 1); // add bias
    for (int i = 0; i < inputDimension; ++i) {
      inputInstance.set(i + 1, trainingInstance.get(i));
    }

    Vector labels =
        trainingInstance.viewPart(inputInstance.size() - 1, trainingInstance.size() - inputInstance.size() + 1);

    // initialize weight update matrices
    Matrix[] weightUpdateMatrices = new Matrix[weightMatrixList.size()];
    for (int m = 0; m < weightUpdateMatrices.length; ++m) {
      weightUpdateMatrices[m] =
          new DenseMatrix(weightMatrixList.get(m).rowSize(), weightMatrixList.get(m).columnSize());
    }

    List<Vector> internalResults = getOutputInternal(inputInstance);

    Vector deltaVec = new DenseVector(layerSizeList.get(layerSizeList.size() - 1));
    Vector output = internalResults.get(internalResults.size() - 1);

    final DoubleFunction derivativeSquashingFunction =
        NeuralNetworkFunctions.getDerivativeDoubleFunction(squashingFunctionList.get(squashingFunctionList.size() - 1));

    final DoubleDoubleFunction costFunction =
        NeuralNetworkFunctions.getDerivativeDoubleDoubleFunction(costFunctionName);

    Matrix lastWeightMatrix = weightMatrixList.get(weightMatrixList.size() - 1);

    for (int i = 0; i < deltaVec.size(); ++i) {
      double costFuncDerivative = costFunction.apply(labels.get(i), output.get(i + 1));
      // Add regularization
      costFuncDerivative += regularizationWeight * lastWeightMatrix.viewRow(i).zSum();
      deltaVec.set(i, costFuncDerivative);
      deltaVec.set(i, deltaVec.get(i) * derivativeSquashingFunction.apply(output.get(i + 1)));
    }

    // Start from previous layer of output layer
    for (int layer = layerSizeList.size() - 2; layer >= 0; --layer) {
      deltaVec = backPropagate(layer, deltaVec, internalResults, weightUpdateMatrices[layer]);
    }

    prevWeightUpdatesList = Arrays.asList(weightUpdateMatrices);

    return weightUpdateMatrices;
  }
This method is longer, so let's analyze it block by block.

First it parses the argument, separating the input features from the output features: the input features are written (preceded by a bias entry of 1) into inputInstance, and the output features into labels. It then initializes weightUpdateMatrices, one matrix per weight layer, to hold the updates.

Next, getOutputInternal() computes the output of every layer in turn; the last of these, the output layer's result, is stored in output, and deltaVec is allocated to hold the output layer's error. The method then fetches the final layer's derivativeSquashingFunction and the costFunction (both fixed when the NN instance was created), as well as lastWeightMatrix, the weight matrix feeding the output layer.

After that, for each output unit it computes costFuncDerivative from the corresponding entries of labels and output (the default cost function is squared error), adds the regularization term weighted by regularizationWeight, and multiplies by the derivative of the squashing function, producing the final deltaVec.

Once that is done, deltaVec is propagated backwards through the earlier layers by the backPropagate() method, which refines deltaVec layer by layer and fills in the update matrices; finally the updated weightUpdateMatrices is returned.

  private Vector backPropagate(int currentLayerIndex, Vector nextLayerDelta,
                               List<Vector> outputCache, Matrix weightUpdateMatrix) {

    // Get layer related information
    final DoubleFunction derivativeSquashingFunction =
        NeuralNetworkFunctions.getDerivativeDoubleFunction(squashingFunctionList.get(currentLayerIndex));
    Vector curLayerOutput = outputCache.get(currentLayerIndex);
    Matrix weightMatrix = weightMatrixList.get(currentLayerIndex);
    Matrix prevWeightMatrix = prevWeightUpdatesList.get(currentLayerIndex);

    // Next layer is not output layer, remove the delta of bias neuron
    if (currentLayerIndex != layerSizeList.size() - 2) {
      nextLayerDelta = nextLayerDelta.viewPart(1, nextLayerDelta.size() - 1);
    }

    Vector delta = weightMatrix.transpose().times(nextLayerDelta);

    delta = delta.assign(curLayerOutput, new DoubleDoubleFunction() {
      @Override
      public double apply(double deltaElem, double curLayerOutputElem) {
        return deltaElem * derivativeSquashingFunction.apply(curLayerOutputElem);
      }
    });

    // Update weights
    for (int i = 0; i < weightUpdateMatrix.rowSize(); ++i) {
      for (int j = 0; j < weightUpdateMatrix.columnSize(); ++j) {
        weightUpdateMatrix.set(i, j, -learningRate * nextLayerDelta.get(i) *
            curLayerOutput.get(j) + this.momentumWeight * prevWeightMatrix.get(i, j));
      }
    }

    return delta;
  }
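The two formulas at the heart of backPropagate() — the delta propagated to the current layer, and the momentum-smoothed weight update — can be restated with plain arrays. This is a sketch assuming a sigmoid squashing function (whose derivative, expressed in terms of the output o, is o·(1−o)); it is illustrative, not Mahout code:

```java
// Illustrative re-statement of backPropagate()'s math with plain arrays.
// Assumes sigmoid activation, so f'(o) = o * (1 - o). Not Mahout code.
public class BackPropSketch {

  // Delta for the current layer: (W^T * nextDelta), elementwise times f'(output)
  static double[] layerDelta(double[][] weight, double[] nextDelta, double[] output) {
    double[] delta = new double[output.length];
    for (int j = 0; j < output.length; j++) {
      double sum = 0.0;
      for (int i = 0; i < nextDelta.length; i++) {
        sum += weight[i][j] * nextDelta[i];           // (W^T * nextDelta)[j]
      }
      delta[j] = sum * output[j] * (1 - output[j]);   // times sigmoid derivative
    }
    return delta;
  }

  // Weight update with learning rate and momentum, as in the nested loop above
  static double[][] weightUpdate(double[] nextDelta, double[] output,
                                 double[][] prevUpdate, double lr, double momentum) {
    double[][] update = new double[nextDelta.length][output.length];
    for (int i = 0; i < nextDelta.length; i++) {
      for (int j = 0; j < output.length; j++) {
        update[i][j] = -lr * nextDelta[i] * output[j] + momentum * prevUpdate[i][j];
      }
    }
    return update;
  }
}
```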

The above is the implementation of the core MLP algorithm. A few details not covered here, such as how costFuncDerivative is computed, how the squashing function is applied, and how backpropagation works in general, can be found in any NN textbook; the implementation here matches the textbook algorithm exactly, so I won't repeat it. What I do want to highlight is the model's serialization and deserialization, because this step is a prerequisite for scaling a model out in a distributed setting.

In the mlp module, serialization and deserialization are implemented by the write() and readFields() methods. Source code:

 public void write(DataOutput output) throws IOException {
    // Write model type
    WritableUtils.writeString(output, modelType);
    // Write learning rate
    output.writeDouble(learningRate);
    // Write model path
    if (modelPath != null) {
      WritableUtils.writeString(output, modelPath);
    } else {
      WritableUtils.writeString(output, "null");
    }

    // Write regularization weight
    output.writeDouble(regularizationWeight);
    // Write momentum weight
    output.writeDouble(momentumWeight);
    // Write cost function
    WritableUtils.writeString(output, costFunctionName);

    // Write layer size list
    output.writeInt(layerSizeList.size());
    for (Integer aLayerSizeList : layerSizeList) {
      output.writeInt(aLayerSizeList);
    }

    WritableUtils.writeEnum(output, trainingMethod);

    // Write squashing functions
    output.writeInt(squashingFunctionList.size());
    for (String aSquashingFunctionList : squashingFunctionList) {
      WritableUtils.writeString(output, aSquashingFunctionList);
    }

    // Write weight matrices
    output.writeInt(this.weightMatrixList.size());
    for (Matrix aWeightMatrixList : weightMatrixList) {
      MatrixWritable.writeMatrix(output, aWeightMatrixList);
    }
  }

  /**
   * Read the fields of the model from input.
   * 
   * @param input The input instance.
   * @throws IOException
   */
  public void readFields(DataInput input) throws IOException {
    // Read model type
    modelType = WritableUtils.readString(input);
    if (!modelType.equals(this.getClass().getSimpleName())) {
      throw new IllegalArgumentException("The specified location does not contains the valid NeuralNetwork model.");
    }
    // Read learning rate
    learningRate = input.readDouble();
    // Read model path
    modelPath = WritableUtils.readString(input);
    if (modelPath.equals("null")) {
      modelPath = null;
    }

    // Read regularization weight
    regularizationWeight = input.readDouble();
    // Read momentum weight
    momentumWeight = input.readDouble();

    // Read cost function
    costFunctionName = WritableUtils.readString(input);

    // Read layer size list
    int numLayers = input.readInt();
    layerSizeList = new ArrayList<>();
    for (int i = 0; i < numLayers; i++) {
      layerSizeList.add(input.readInt());
    }

    trainingMethod = WritableUtils.readEnum(input, TrainingMethod.class);

    // Read squash functions
    int squashingFunctionSize = input.readInt();
    squashingFunctionList = new ArrayList<>();
    for (int i = 0; i < squashingFunctionSize; i++) {
      squashingFunctionList.add(WritableUtils.readString(input));
    }

    // Read weights and construct matrices of previous updates
    int numOfMatrices = input.readInt();
    weightMatrixList = new ArrayList<>();
    prevWeightUpdatesList = new ArrayList<>();
    for (int i = 0; i < numOfMatrices; i++) {
      Matrix matrix = MatrixWritable.readMatrix(input);
      weightMatrixList.add(matrix);
      prevWeightUpdatesList.add(new DenseMatrix(matrix.rowSize(), matrix.columnSize()));
    }
  }
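The discipline behind this pair of methods is that readFields() must consume fields in exactly the order write() produced them, with variable-length collections preceded by a count. The same round-trip pattern can be exercised with plain java.io streams; this sketch uses writeUTF/readUTF in place of Hadoop's WritableUtils and covers only a few representative fields — it is illustrative, not Mahout code:

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

// Round-trip sketch: write a few model-like fields, then read them back in
// the same order. Uses DataOutputStream/DataInputStream and writeUTF instead
// of Hadoop's WritableUtils; the field-order discipline is the point.
public class SerializationSketch {

  static byte[] write(String modelType, double learningRate, List<Integer> layerSizes)
      throws IOException {
    ByteArrayOutputStream bytes = new ByteArrayOutputStream();
    DataOutputStream out = new DataOutputStream(bytes);
    out.writeUTF(modelType);            // model type first, as in write()
    out.writeDouble(learningRate);
    out.writeInt(layerSizes.size());    // length prefix, then each entry
    for (int size : layerSizes) {
      out.writeInt(size);
    }
    return bytes.toByteArray();
  }

  static Object[] read(byte[] data) throws IOException {
    DataInputStream in = new DataInputStream(new ByteArrayInputStream(data));
    String modelType = in.readUTF();    // fields are read back in write order
    double learningRate = in.readDouble();
    int numLayers = in.readInt();
    List<Integer> layerSizes = new ArrayList<>();
    for (int i = 0; i < numLayers; i++) {
      layerSizes.add(in.readInt());
    }
    return new Object[] {modelType, learningRate, layerSizes};
  }
}
```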



