k-nn gekürzt

This commit is contained in:
paul-loedige 2021-12-02 01:14:54 +01:00
parent 55e02ccc7c
commit 41aac60d78

View File

@ -23,7 +23,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 72,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -81,7 +81,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 73,
"metadata": { "metadata": {
"pycharm": { "pycharm": {
"name": "#%%\n" "name": "#%%\n"
@ -134,7 +134,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 74,
"metadata": { "metadata": {
"pycharm": { "pycharm": {
"name": "#%%\n" "name": "#%%\n"
@ -169,7 +169,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 75,
"metadata": { "metadata": {
"pycharm": { "pycharm": {
"name": "#%%\n" "name": "#%%\n"
@ -293,7 +293,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 76,
"metadata": { "metadata": {
"pycharm": { "pycharm": {
"name": "#%%\n" "name": "#%%\n"
@ -438,7 +438,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 77,
"metadata": { "metadata": {
"scrolled": true "scrolled": true
}, },
@ -522,9 +522,15 @@
" :return k-nearest x values [k x input_dimension], k-nearest y values [k x target_dimension]\n", " :return k-nearest x values [k x input_dimension], k-nearest y values [k x target_dimension]\n",
" \"\"\"\n", " \"\"\"\n",
"\n", "\n",
" distance = []\n", " distances = []\n",
" for i in range(x_data.shape[0]):\n", " for i in range(x_data.shape[0]):\n",
" distance.append(np.linalg.norm(x_data[i]-query_point))\n", " distances.append(np.linalg.norm(x_data[i]-query_point))\n",
" \n",
" data = np.column_stack((x_data,y_data,distances))\n",
" data = data[data[:,-1].argsort()]\n",
" nearest_x = data[0:k]\n",
" nearest_y = data[0:k,-2]\n",
" return nearest_x, nearest_y\n",
"# point = np.array([x_data[i], y_data[i]])\n", "# point = np.array([x_data[i], y_data[i]])\n",
"# dist = 0\n", "# dist = 0\n",
"# for j in range(len(point[0])):\n", "# for j in range(len(point[0])):\n",
@ -537,15 +543,15 @@
" #idx = np.argpartition(distance, k)\n", " #idx = np.argpartition(distance, k)\n",
" #idx = idx[:k]\n", " #idx = idx[:k]\n",
"\n", "\n",
" idx = sorted(range(len(distance)), key=lambda i: distance[i])[:k]\n", "# idx = sorted(range(len(distance)), key=lambda i: distance[i])[:k]\n",
"\n", "#\n",
" nearest_x = []\n", "# nearest_x = []\n",
" nearest_y = []\n", "# nearest_y = []\n",
" for j in range(len(idx)):\n", "# for j in range(len(idx)):\n",
" nearest_x.append(x_data[idx[j]])\n", "# nearest_x.append(x_data[idx[j]])\n",
" nearest_y.append(y_data[idx[j]])\n", "# nearest_y.append(y_data[idx[j]])\n",
"\n", "#\n",
" return nearest_x, nearest_y\n", "# return nearest_x, nearest_y\n",
"\n", "\n",
"\n", "\n",
"\n", "\n",
@ -566,17 +572,9 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 69, "execution_count": 97,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [],
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy: 1.0\n"
]
}
],
"source": [ "source": [
"k = 5\n", "k = 5\n",
"predictions = np.zeros(test_features.shape[0])\n", "predictions = np.zeros(test_features.shape[0])\n",
@ -1321,7 +1319,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 68, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {