Subversion Repositories ESP8266_P1_Meter

Rev

Details | Last modification | View Log | RSS feed

Rev Author Line No. Line
2 raymond 1
// SPDX-License-Identifier: LGPL-3.0-or-later
2
// Copyright 2016-2026 Hristo Gochkov, Mathieu Carbou, Emil Muratov, Will Miles
3
 
4
#include "WebAuthentication.h"
5
#include <ESPAsyncWebServer.h>
6
 
7
#include <list>
8
 
9
// Tears down the chain: every registered middleware flagged with
// _freeOnRemoval is deleted; the others are left untouched.
AsyncMiddlewareChain::~AsyncMiddlewareChain() {
  for (auto it = _middlewares.begin(); it != _middlewares.end(); ++it) {
    AsyncMiddleware *entry = *it;
    if (!entry->_freeOnRemoval) {
      continue;  // not flagged for release by the chain
    }
    delete entry;
  }
}
16
 
17
// Wraps a plain callback in a heap-allocated AsyncMiddlewareFunction and
// appends it to the chain. The wrapper is flagged _freeOnRemoval so the chain
// deletes it when it is removed or when the chain is destroyed.
void AsyncMiddlewareChain::addMiddleware(ArMiddlewareCallback fn) {
  auto *wrapper = new AsyncMiddlewareFunction(fn);
  wrapper->_freeOnRemoval = true;
  _middlewares.push_back(wrapper);
}
22
 
23
// Appends an existing middleware instance to the chain.
// A null pointer is silently ignored.
void AsyncMiddlewareChain::addMiddleware(AsyncMiddleware *middleware) {
  if (middleware == nullptr) {
    return;
  }
  _middlewares.push_back(middleware);
}
28
 
29
// Appends several middlewares, in the order given, by delegating each one to
// addMiddleware(AsyncMiddleware *) (which skips null entries).
void AsyncMiddlewareChain::addMiddlewares(std::vector<AsyncMiddleware *> middlewares) {
  for (auto it = middlewares.begin(); it != middlewares.end(); ++it) {
    addMiddleware(*it);
  }
}
34
 
35
// Removes every occurrence of `middleware` from the chain. An entry flagged
// with _freeOnRemoval is deleted as it is removed.
// Returns true when at least one entry was removed.
bool AsyncMiddlewareChain::removeMiddleware(AsyncMiddleware *middleware) {
  bool removedAny = false;
  for (auto pos = _middlewares.begin(); pos != _middlewares.end();) {
    AsyncMiddleware *candidate = *pos;
    if (candidate != middleware) {
      ++pos;
      continue;
    }
    if (candidate->_freeOnRemoval) {
      delete candidate;
    }
    // erase() hands back the iterator following the removed element
    pos = _middlewares.erase(pos);
    removedAny = true;
  }
  return removedAny;
}
55
 
56
// Runs the registered middlewares in order and invokes `finalizer` once the
// end of the chain is reached. Each middleware receives a `next` callback; the
// following middleware (or the finalizer) only runs if it calls that callback.
//
// NOTE: the callback captures the local cursor and itself by reference, so it
// is only valid while this function is still on the stack.
void AsyncMiddlewareChain::_runChain(AsyncWebServerRequest *request, ArMiddlewareNext finalizer) {
  if (_middlewares.empty()) {
    finalizer();
    return;
  }

  auto cursor = _middlewares.begin();
  ArMiddlewareNext advance;
  advance = [this, &advance, &cursor, request, finalizer]() {
    if (cursor == _middlewares.end()) {
      finalizer();
      return;
    }
    // step the cursor before running, so a nested advance() sees the next entry
    AsyncMiddleware *current = *cursor;
    ++cursor;
    current->run(request, advance);
  };
  advance();
}
72
 
73
void AsyncAuthenticationMiddleware::setUsername(const char *username) {
74
  _username = username;
75
  _hasCreds = _username.length() && _credentials.length();
76
}
77
 
78
void AsyncAuthenticationMiddleware::setPassword(const char *password) {
79
  _credentials = password;
80
  _hash = false;
81
  _hasCreds = _username.length() && _credentials.length();
82
}
83
 
84
void AsyncAuthenticationMiddleware::setPasswordHash(const char *hash) {
85
  _credentials = hash;
86
  _hash = _credentials.length();
87
  _hasCreds = _username.length() && _credentials.length();
88
}
89
 
90
void AsyncAuthenticationMiddleware::setToken(const char *token) {
91
  _credentials = token;
92
  _hash = _credentials.length();
93
  _hasCreds = _credentials.length();
94
}
95
 
96
bool AsyncAuthenticationMiddleware::generateHash() {
97
  // ensure we have all the necessary data
98
  if (!_hasCreds) {
99
    return false;
100
  }
101
 
102
  // if we already have a hash, do nothing
103
  if (_hash) {
104
    return false;
105
  }
106
 
107
  switch (_authMethod) {
108
    case AsyncAuthType::AUTH_DIGEST:
109
      _credentials = generateDigestHash(_username.c_str(), _credentials.c_str(), _realm.c_str());
110
      if (_credentials.length()) {
111
        _hash = true;
112
        return true;
113
      } else {
114
        return false;
115
      }
116
 
117
    case AsyncAuthType::AUTH_BASIC:
118
      _credentials = generateBasicHash(_username.c_str(), _credentials.c_str());
119
      if (_credentials.length()) {
120
        _hash = true;
121
        return true;
122
      } else {
123
        return false;
124
      }
125
 
126
    default: return false;
127
  }
128
}
129
 
130
// Decides whether `request` may pass, based on the configured auth method:
//  - AUTH_NONE: always allowed
//  - AUTH_DENIED or any unknown method: never allowed
//  - AUTH_BEARER / AUTH_OTHER: result of _authcFunc(request)
//  - AUTH_BASIC / AUTH_DIGEST: allowed when no credentials are configured,
//    otherwise result of _authcFunc(request)
bool AsyncAuthenticationMiddleware::allowed(AsyncWebServerRequest *request) const {
  switch (_authMethod) {
    case AsyncAuthType::AUTH_NONE:
      return true;

    case AsyncAuthType::AUTH_BEARER:
    case AsyncAuthType::AUTH_OTHER:
      return _authcFunc(request);

    case AsyncAuthType::AUTH_BASIC:
    case AsyncAuthType::AUTH_DIGEST:
      if (!_hasCreds) {
        return true;
      }
      return _authcFunc(request);

    case AsyncAuthType::AUTH_DENIED:
    default:
      return false;
  }
}
141
 
142
// Lets the request continue down the chain when allowed(); otherwise asks the
// client to authenticate using the configured method, realm and failure message.
void AsyncAuthenticationMiddleware::run(AsyncWebServerRequest *request, ArMiddlewareNext next) {
  if (!allowed(request)) {
    request->requestAuthentication(_authMethod, _realm.c_str(), _authFailMsg.c_str());
    return;
  }
  next();
}
145
 
146
// Strips every request header whose name is not in _toKeep (compared
// case-insensitively), then continues the chain.
//
// The names to remove are collected as owned String copies. Storing the raw
// c_str() pointers of the headers would hand removeHeader() a pointer into the
// very header it erases: that pointer dangles as soon as the header is
// destroyed, and when two headers share a name the second collected pointer is
// already dangling before it is used.
void AsyncHeaderFreeMiddleware::run(AsyncWebServerRequest *request, ArMiddlewareNext next) {
  std::list<String> toRemove;
  for (auto &h : request->getHeaders()) {
    bool keep = false;
    for (const char *k : _toKeep) {
      if (strcasecmp(h.name().c_str(), k) == 0) {
        keep = true;
        break;
      }
    }
    if (!keep) {
      // copy the name: it must outlive the header it belongs to
      toRemove.push_back(h.name());
    }
  }
  // removal happens after iteration so the header list is not modified while walked
  for (const String &name : toRemove) {
    request->removeHeader(name.c_str());
  }
  next();
}
165
 
166
// Removes from the request every header listed in _toRemove, then continues
// the chain.
void AsyncHeaderFilterMiddleware::run(AsyncWebServerRequest *request, ArMiddlewareNext next) {
  for (const auto &name : _toRemove) {
    request->removeHeader(name);
  }
  next();
}
172
 
173
// Logs the incoming request to _out, runs the rest of the chain, then logs the
// response (or a "connection closed" notice when no response is attached to
// the request). When logging is disabled the chain is simply continued.
//
// Request lines are prefixed with '>' and response lines with '<'; headers
// with an empty value are skipped. The elapsed time is measured around next().
void AsyncLoggingMiddleware::run(AsyncWebServerRequest *request, ArMiddlewareNext next) {
  if (!isEnabled()) {
    next();
    return;
  }

  // --- peer address ---
  _out->print(F("* Connection from "));
#ifndef LIBRETINY
  _out->print(request->client()->remoteIP().toString());
#else
  _out->print(request->client()->remoteIP());
#endif
  _out->print(':');
  _out->println(request->client()->remotePort());

  // --- request line ---
  _out->print('>');
  _out->print(' ');
  _out->print(request->methodToString());
  _out->print(' ');
  _out->print(request->url().c_str());
  _out->print(F(" HTTP/1."));
  _out->println(request->version());

  // --- request headers ---
  for (auto &reqHeader : request->getHeaders()) {
    if (!reqHeader.value().length()) {
      continue;
    }
    _out->print('>');
    _out->print(' ');
    _out->print(reqHeader.name());
    _out->print(':');
    _out->print(' ');
    _out->println(reqHeader.value());
  }
  _out->println(F(">"));

  // --- run the rest of the chain, timing it ---
  const uint32_t startedAt = millis();
  next();
  const uint32_t elapsed = millis() - startedAt;

  AsyncWebServerResponse *response = request->getResponse();
  if (!response) {
    _out->println(F("* Connection closed!"));
    return;
  }

  // --- status line ---
  _out->print(F("* Processed in "));
  _out->print(elapsed);
  _out->println(F(" ms"));
  _out->print('<');
  _out->print(F(" HTTP/1."));
  _out->print(request->version());
  _out->print(' ');
  _out->print(response->code());
  _out->print(' ');
  _out->println(AsyncWebServerResponse::responseCodeToString(response->code()));

  // --- response headers ---
  for (auto &respHeader : response->getHeaders()) {
    if (!respHeader.value().length()) {
      continue;
    }
    _out->print('<');
    _out->print(' ');
    _out->print(respHeader.name());
    _out->print(':');
    _out->print(' ');
    _out->println(respHeader.value());
  }
  _out->println('<');
}
234
 
235
// Adds the CORS response headers (allow-origin, allow-methods, allow-headers,
// allow-credentials, max-age) to `response` from the middleware configuration.
//
// When credentials are enabled and the configured origin is the wildcard "*",
// the request's own Origin header value is echoed back instead, because a
// wildcard cannot be combined with credentials. `request` may be null, in
// which case the configured origin is always used.
void AsyncCorsMiddleware::addCORSHeaders(AsyncWebServerRequest *request, AsyncWebServerResponse *response) {
  const bool echoRequestOrigin = request != nullptr && _credentials && _origin == "*";
  if (echoRequestOrigin) {
    response->addHeader(asyncsrv::T_CORS_ACAO, request->header(asyncsrv::T_CORS_O).c_str());
  } else {
    response->addHeader(asyncsrv::T_CORS_ACAO, _origin.c_str());
  }

  response->addHeader(asyncsrv::T_CORS_ACAM, _methods.c_str());
  response->addHeader(asyncsrv::T_CORS_ACAH, _headers.c_str());
  response->addHeader(asyncsrv::T_CORS_ACAC, _credentials ? asyncsrv::T_TRUE : asyncsrv::T_FALSE);

  const String maxAgeText = String(_maxAge);
  response->addHeader(asyncsrv::T_CORS_ACMA, maxAgeText.c_str());
}
247
 
248
// CORS handling for one request:
//  - no Origin header: not a CORS request, just continue the chain;
//  - Origin header + OPTIONS method: preflight, answered here with a 200
//    carrying the CORS headers (the chain is not continued);
//  - Origin header + any other method: continue the chain, then decorate the
//    resulting response (if any) with the CORS headers.
void AsyncCorsMiddleware::run(AsyncWebServerRequest *request, ArMiddlewareNext next) {
  if (!request->hasHeader(asyncsrv::T_CORS_O)) {
    next();
    return;
  }

  if (request->method() == AsyncWebRequestMethod::HTTP_OPTIONS) {
    AsyncWebServerResponse *preflight = request->beginResponse(200);
    addCORSHeaders(request, preflight);
    request->send(preflight);
    return;
  }

  next();
  AsyncWebServerResponse *produced = request->getResponse();
  if (produced) {
    addCORSHeaders(request, produced);
  }
}
271
 
272
bool AsyncRateLimitMiddleware::isRequestAllowed(uint32_t &retryAfterSeconds) {
273
  uint32_t now = millis();
274
 
275
  while (!_requestTimes.empty() && _requestTimes.front() <= now - _windowSizeMillis) {
276
    _requestTimes.pop_front();
277
  }
278
 
279
  _requestTimes.push_back(now);
280
 
281
  if (_requestTimes.size() > _maxRequests) {
282
    _requestTimes.pop_front();
283
    retryAfterSeconds = (_windowSizeMillis - (now - _requestTimes.front())) / 1000 + 1;
284
    return false;
285
  }
286
 
287
  retryAfterSeconds = 0;
288
  return true;
289
}
290
 
291
// Continues the chain when the rate limiter allows the request; otherwise
// answers with HTTP 429 and a retry-after header holding the wait in seconds.
void AsyncRateLimitMiddleware::run(AsyncWebServerRequest *request, ArMiddlewareNext next) {
  uint32_t retryAfterSeconds;
  const bool permitted = isRequestAllowed(retryAfterSeconds);
  if (permitted) {
    next();
    return;
  }

  AsyncWebServerResponse *tooMany = request->beginResponse(429);
  tooMany->addHeader(asyncsrv::T_retry_after, retryAfterSeconds);
  request->send(tooMany);
}
300
}