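The following example uses TensorFlow's high-level API to implement a linear regression model.
TensorFlow's high-level API consists mainly of the model class interfaces provided by tf.keras.models.
The Keras interface offers three ways to build a model: stacking layers in order with Sequential, building models of arbitrary structure with the functional API, and subclassing the Model base class to define a custom model.
This section demonstrates two of them: building a model layer by layer with Sequential, and subclassing the Model base class to build a custom model.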
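For completeness, the remaining approach, the functional API, is not demonstrated in this section. A minimal sketch of the same single-output linear model might look like the following. The variable names (inputs, outputs, linear_fn) are illustrative assumptions and do not come from the original listings.

```python
import tensorflow as tf
from tensorflow.keras import layers, models

# Assumed illustration: the same 2-feature, 1-output linear model,
# expressed with the functional API instead of Sequential.
inputs = layers.Input(shape=(2,))      # two input features per sample
outputs = layers.Dense(1)(inputs)      # one linear unit: 2 weights + 1 bias
linear_fn = models.Model(inputs=inputs, outputs=outputs)

# Same training configuration as the Sequential example.
linear_fn.compile(optimizer="adam", loss="mse", metrics=["mae"])

print(linear_fn.count_params())  # 3
```

The functional style pays off mainly when a model has branches, multiple inputs, or multiple outputs. For a single Dense layer it is equivalent to the Sequential version.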
# Approach 1: build the model layer by layer with Sequential
import tensorflow as tf
from tensorflow.keras import models, layers, optimizers

# Number of samples
n = 800

# Generate a synthetic dataset
X = tf.random.uniform([n, 2], minval=-10, maxval=10)
w0 = tf.constant([[2.0], [-1.0]])
b0 = tf.constant(3.0)

# @ is matrix multiplication; normally distributed noise is added
Y = X @ w0 + b0 + tf.random.normal([n, 1], mean=0.0, stddev=2.0)

tf.keras.backend.clear_session()

linear = models.Sequential()
linear.add(layers.Dense(1, input_shape=(2,)))
linear.summary()

### Train with the fit method
linear.compile(optimizer="adam", loss="mse", metrics=["mae"])
linear.fit(X, Y, batch_size=20, epochs=200)

# Print the learned weights and bias
tf.print("w = ", linear.layers[0].kernel)
tf.print("b = ", linear.layers[0].bias)
Result:
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
dense (Dense)                (None, 1)                 3
=================================================================
Total params: 3
Trainable params: 3
Non-trainable params: 0
_________________________________________________________________
Epoch 1/200
40/40 [==============================] - 0s 908us/step - loss: 195.5055 - mae: 11.7040
Epoch 2/200
40/40 [==============================] - 0s 870us/step - loss: 188.2559 - mae: 11.4891
Epoch 3/200
40/40 [==============================] - 0s 820us/step - loss: 181.3084 - mae: 11.2766
Epoch 4/200
40/40 [==============================] - 0s 859us/step - loss: 174.4538 - mae: 11.0680
Epoch 5/200
40/40 [==============================] - 0s 886us/step - loss: 167.8749 - mae: 10.8582
...
Epoch 196/200
40/40 [==============================] - 0s 925us/step - loss: 3.5371 - mae: 1.5164
Epoch 197/200
40/40 [==============================] - 0s 995us/step - loss: 3.5368 - mae: 1.5161
Epoch 198/200
40/40 [==============================] - 0s 957us/step - loss: 3.5380 - mae: 1.5161
Epoch 199/200
40/40 [==============================] - 0s 923us/step - loss: 3.5391 - mae: 1.5162
Epoch 200/200
40/40 [==============================] - 0s 899us/step - loss: 3.5368 - mae: 1.5160
w =  [[2.00381827]
 [-0.98936516]]
b =  [2.9572618]
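Several numbers in the output above can be sanity-checked with simple arithmetic: the 40 steps per epoch, the 3 trainable parameters, and the level at which the loss and MAE stop improving. The sketch below uses only values from the listing (n = 800, batch size 20, noise stddev 2.0). It assumes the noise is the only source of irreducible error, so the floors it computes are population values and any single finite sample will deviate from them somewhat.

```python
import math

n = 800            # number of samples, as in the listing
batch_size = 20    # batch size passed to fit
noise_std = 2.0    # stddev of the normal noise added to Y

# Batches per epoch, shown as "40/40" in the progress bar.
steps_per_epoch = math.ceil(n / batch_size)

# Dense(1) on 2 inputs: a 2x1 kernel plus 1 bias.
n_params = 2 * 1 + 1

# With a perfect fit, the residual is just the noise:
# expected squared error is the noise variance,
# expected absolute error of N(0, s^2) is s * sqrt(2 / pi).
mse_floor = noise_std ** 2
mae_floor = noise_std * math.sqrt(2.0 / math.pi)

print(steps_per_epoch, n_params)                  # 40 3
print(round(mse_floor, 3), round(mae_floor, 3))   # 4.0 1.596
```

The final logged values (loss 3.5368, mae 1.5160) sit near these floors, which suggests the model has essentially converged and that further epochs would not help.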
# Approach 2: subclass the Model base class and train with a custom loop
import tensorflow as tf
from tensorflow.keras import models, layers, optimizers, losses, metrics

# Print a separator line followed by the current time
@tf.function
def printbar():
    ts = tf.timestamp()
    today_ts = ts % (24 * 60 * 60)

    # +8 converts UTC to UTC+8 (Beijing time)
    hour = tf.cast(today_ts // 3600 + 8, tf.int32) % tf.constant(24)
    minute = tf.cast((today_ts % 3600) // 60, tf.int32)
    second = tf.cast(tf.floor(today_ts % 60), tf.int32)

    # Zero-pad single-digit values
    def timeformat(m):
        if tf.strings.length(tf.strings.format("{}", m)) == 1:
            return tf.strings.format("0{}", m)
        else:
            return tf.strings.format("{}", m)

    timestring = tf.strings.join([timeformat(hour), timeformat(minute),
                                  timeformat(second)], separator=":")
    tf.print("==========" * 8, end="")
    tf.print(timestring)

# Number of samples
n = 800

# Generate a synthetic dataset
X = tf.random.uniform([n, 2], minval=-10, maxval=10)
w0 = tf.constant([[2.0], [-1.0]])
b0 = tf.constant(3.0)

# @ is matrix multiplication; normally distributed noise is added
Y = X @ w0 + b0 + tf.random.normal([n, 1], mean=0.0, stddev=2.0)

# First 3/4 of the samples for training, the rest for validation
ds_train = tf.data.Dataset.from_tensor_slices((X[0:n*3//4, :], Y[0:n*3//4, :])) \
    .shuffle(buffer_size=1000).batch(20) \
    .prefetch(tf.data.experimental.AUTOTUNE) \
    .cache()

ds_valid = tf.data.Dataset.from_tensor_slices((X[n*3//4:, :], Y[n*3//4:, :])) \
    .shuffle(buffer_size=1000).batch(20) \
    .prefetch(tf.data.experimental.AUTOTUNE) \
    .cache()

tf.keras.backend.clear_session()

class MyModel(models.Model):
    def __init__(self):
        super(MyModel, self).__init__()

    def build(self, input_shape):
        # Layers are created in build, once the input shape is known
        self.dense1 = layers.Dense(1)
        super(MyModel, self).build(input_shape)

    def call(self, x):
        y = self.dense1(x)
        return y

model = MyModel()
model.build(input_shape=(None, 2))
model.summary()

### Custom training loop (expert tutorial style)
optimizer = optimizers.Adam()
loss_func = losses.MeanSquaredError()

train_loss = tf.keras.metrics.Mean(name='train_loss')
train_metric = tf.keras.metrics.MeanAbsoluteError(name='train_mae')

valid_loss = tf.keras.metrics.Mean(name='valid_loss')
valid_metric = tf.keras.metrics.MeanAbsoluteError(name='valid_mae')

@tf.function
def train_step(model, features, labels):
    # Forward pass under the tape, then one optimizer update
    with tf.GradientTape() as tape:
        predictions = model(features)
        loss = loss_func(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    train_loss.update_state(loss)
    train_metric.update_state(labels, predictions)

@tf.function
def valid_step(model, features, labels):
    predictions = model(features)
    batch_loss = loss_func(labels, predictions)
    valid_loss.update_state(batch_loss)
    valid_metric.update_state(labels, predictions)

@tf.function
def train_model(model, ds_train, ds_valid, epochs):
    for epoch in tf.range(1, epochs + 1):
        for features, labels in ds_train:
            train_step(model, features, labels)

        for features, labels in ds_valid:
            valid_step(model, features, labels)

        logs = 'Epoch={},Loss:{},MAE:{},Valid Loss:{},Valid MAE:{}'

        # Report every 100 epochs
        if epoch % 100 == 0:
            printbar()
            tf.print(tf.strings.format(logs,
                (epoch, train_loss.result(), train_metric.result(),
                 valid_loss.result(), valid_metric.result())))
            tf.print("w=", model.layers[0].kernel)
            tf.print("b=", model.layers[0].bias)
            tf.print("")

        # Reset the metric accumulators at the end of each epoch
        # (newer TensorFlow releases name this method reset_state())
        train_loss.reset_states()
        valid_loss.reset_states()
        train_metric.reset_states()
        valid_metric.reset_states()

train_model(model, ds_train, ds_valid, 400)
Result:
Model: "my_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
dense (Dense)                multiple                  3
=================================================================
Total params: 3
Trainable params: 3
Non-trainable params: 0
_________________________________________________________________
================================================================================15:40:27
Epoch=100,Loss:7.5666852,MAE:2.1710279,Valid Loss:6.50372219,Valid MAE:2.06310129
w= [[1.78483891]
 [-0.941808105]]
b= [1.89865637]

================================================================================15:40:34
Epoch=200,Loss:4.18288374,MAE:1.6310848,Valid Loss:3.79517508,Valid MAE:1.53697133
w= [[2.02300119]
 [-0.992656231]]
b= [2.88763976]

================================================================================15:40:42
Epoch=300,Loss:4.17580175,MAE:1.62464666,Valid Loss:3.80199885,Valid MAE:1.53819764
w= [[2.02173]
 [-0.992035568]]
b= [2.97494888]

================================================================================15:40:49
Epoch=400,Loss:4.17601919,MAE:1.6246767,Valid Loss:3.80182695,Valid MAE:1.53820801
w= [[2.02159858]
 [-0.992003262]]
b= [2.97537684]
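The timestamps at the end of each separator line come from the printbar helper in the listing, which turns a Unix timestamp into an HH:MM:SS string with tensor operations so that it can run inside a tf.function. The arithmetic is easier to follow in plain Python. The sketch below is an assumed equivalent for illustration only: format_clock and its utc_offset_hours parameter are hypothetical names, not part of the original code.

```python
def format_clock(ts, utc_offset_hours=8):
    """Format a Unix timestamp (seconds) as HH:MM:SS in a fixed UTC offset.

    Hypothetical plain-Python counterpart of printbar's arithmetic.
    """
    today_ts = ts % (24 * 60 * 60)                       # seconds since UTC midnight
    hour = (int(today_ts // 3600) + utc_offset_hours) % 24
    minute = int((today_ts % 3600) // 60)
    second = int(today_ts % 60)
    # {:02d} zero-pads to two digits, the job timeformat does in the listing.
    return "{:02d}:{:02d}:{:02d}".format(hour, minute, second)

print(format_clock(0))       # 08:00:00
print(format_clock(27627))   # 15:40:27
```

Outside of a tf.function, the standard library's time or datetime modules would be the more usual choice. The manual arithmetic in the listing exists because ordinary Python formatting would be evaluated only once, when the function is traced into a graph.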
References:
Open-source e-book: https://lyhue1991.github.io/eat_tensorflow2_in_30_days/
GitHub project: https://github.com/lyhue1991/eat_tensorflow2_in_30_days